gaur3009 committed
Commit a96e5dd · verified · 1 Parent(s): 6cd4678

Upload 10 files

src/cyclegan_turbo.py ADDED
@@ -0,0 +1,254 @@
1
+ import os
2
+ import sys
3
+ import copy
4
+ import torch
5
+ import torch.nn as nn
6
+ from transformers import AutoTokenizer, CLIPTextModel
7
+ from diffusers import AutoencoderKL, UNet2DConditionModel
8
+ from peft import LoraConfig
9
+ from peft.utils import get_peft_model_state_dict
10
+ p = "src/"
11
+ sys.path.append(p)
12
+ from model import make_1step_sched, my_vae_encoder_fwd, my_vae_decoder_fwd, download_url
13
+
14
+
15
+ class VAE_encode(nn.Module):
16
+ def __init__(self, vae, vae_b2a=None):
17
+ super(VAE_encode, self).__init__()
18
+ self.vae = vae
19
+ self.vae_b2a = vae_b2a
20
+
21
+ def forward(self, x, direction):
22
+ assert direction in ["a2b", "b2a"]
23
+ if direction == "a2b":
24
+ _vae = self.vae
25
+ else:
26
+ _vae = self.vae_b2a
27
+ return _vae.encode(x).latent_dist.sample() * _vae.config.scaling_factor
28
+
29
+
30
+ class VAE_decode(nn.Module):
31
+ def __init__(self, vae, vae_b2a=None):
32
+ super(VAE_decode, self).__init__()
33
+ self.vae = vae
34
+ self.vae_b2a = vae_b2a
35
+
36
+ def forward(self, x, direction):
37
+ assert direction in ["a2b", "b2a"]
38
+ if direction == "a2b":
39
+ _vae = self.vae
40
+ else:
41
+ _vae = self.vae_b2a
42
+ assert _vae.encoder.current_down_blocks is not None
43
+ _vae.decoder.incoming_skip_acts = _vae.encoder.current_down_blocks
44
+ x_decoded = (_vae.decode(x / _vae.config.scaling_factor).sample).clamp(-1, 1)
45
+ return x_decoded
46
+
47
+
48
+ def initialize_unet(rank, return_lora_module_names=False):
49
+ unet = UNet2DConditionModel.from_pretrained("stabilityai/sd-turbo", subfolder="unet")
50
+ unet.requires_grad_(False)
51
+ unet.train()
52
+ l_target_modules_encoder, l_target_modules_decoder, l_modules_others = [], [], []
53
+ l_grep = ["to_k", "to_q", "to_v", "to_out.0", "conv", "conv1", "conv2", "conv_in", "conv_shortcut", "conv_out", "proj_out", "proj_in", "ff.net.2", "ff.net.0.proj"]
54
+ for n, p in unet.named_parameters():
55
+ if "bias" in n or "norm" in n: continue
56
+ for pattern in l_grep:
57
+ if pattern in n and ("down_blocks" in n or "conv_in" in n):
58
+ l_target_modules_encoder.append(n.replace(".weight",""))
59
+ break
60
+ elif pattern in n and "up_blocks" in n:
61
+ l_target_modules_decoder.append(n.replace(".weight",""))
62
+ break
63
+ elif pattern in n:
64
+ l_modules_others.append(n.replace(".weight",""))
65
+ break
66
+ lora_conf_encoder = LoraConfig(r=rank, init_lora_weights="gaussian",target_modules=l_target_modules_encoder, lora_alpha=rank)
67
+ lora_conf_decoder = LoraConfig(r=rank, init_lora_weights="gaussian",target_modules=l_target_modules_decoder, lora_alpha=rank)
68
+ lora_conf_others = LoraConfig(r=rank, init_lora_weights="gaussian",target_modules=l_modules_others, lora_alpha=rank)
69
+ unet.add_adapter(lora_conf_encoder, adapter_name="default_encoder")
70
+ unet.add_adapter(lora_conf_decoder, adapter_name="default_decoder")
71
+ unet.add_adapter(lora_conf_others, adapter_name="default_others")
72
+ unet.set_adapters(["default_encoder", "default_decoder", "default_others"])
73
+ if return_lora_module_names:
74
+ return unet, l_target_modules_encoder, l_target_modules_decoder, l_modules_others
75
+ else:
76
+ return unet
77
+
78
+
79
+ def initialize_vae(rank=4, return_lora_module_names=False):
80
+ vae = AutoencoderKL.from_pretrained("stabilityai/sd-turbo", subfolder="vae")
81
+ vae.requires_grad_(False)
82
+ vae.encoder.forward = my_vae_encoder_fwd.__get__(vae.encoder, vae.encoder.__class__)
83
+ vae.decoder.forward = my_vae_decoder_fwd.__get__(vae.decoder, vae.decoder.__class__)
84
+ vae.requires_grad_(True)
85
+ vae.train()
86
+ # add the skip connection convs
87
+ vae.decoder.skip_conv_1 = torch.nn.Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda().requires_grad_(True)
88
+ vae.decoder.skip_conv_2 = torch.nn.Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda().requires_grad_(True)
89
+ vae.decoder.skip_conv_3 = torch.nn.Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda().requires_grad_(True)
90
+ vae.decoder.skip_conv_4 = torch.nn.Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda().requires_grad_(True)
91
+ torch.nn.init.constant_(vae.decoder.skip_conv_1.weight, 1e-5)
92
+ torch.nn.init.constant_(vae.decoder.skip_conv_2.weight, 1e-5)
93
+ torch.nn.init.constant_(vae.decoder.skip_conv_3.weight, 1e-5)
94
+ torch.nn.init.constant_(vae.decoder.skip_conv_4.weight, 1e-5)
95
+ vae.decoder.ignore_skip = False
96
+ vae.decoder.gamma = 1
97
+ l_vae_target_modules = ["conv1","conv2","conv_in", "conv_shortcut",
98
+ "conv", "conv_out", "skip_conv_1", "skip_conv_2", "skip_conv_3",
99
+ "skip_conv_4", "to_k", "to_q", "to_v", "to_out.0",
100
+ ]
101
+ vae_lora_config = LoraConfig(r=rank, init_lora_weights="gaussian", target_modules=l_vae_target_modules)
102
+ vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
103
+ if return_lora_module_names:
104
+ return vae, l_vae_target_modules
105
+ else:
106
+ return vae
107
+
108
+
109
+ class CycleGAN_Turbo(torch.nn.Module):
110
+ def __init__(self, pretrained_name=None, pretrained_path=None, ckpt_folder="checkpoints", lora_rank_unet=8, lora_rank_vae=4):
111
+ super().__init__()
112
+ self.tokenizer = AutoTokenizer.from_pretrained("stabilityai/sd-turbo", subfolder="tokenizer")
113
+ self.text_encoder = CLIPTextModel.from_pretrained("stabilityai/sd-turbo", subfolder="text_encoder").cuda()
114
+ self.sched = make_1step_sched()
115
+ vae = AutoencoderKL.from_pretrained("stabilityai/sd-turbo", subfolder="vae")
116
+ unet = UNet2DConditionModel.from_pretrained("stabilityai/sd-turbo", subfolder="unet")
117
+ vae.encoder.forward = my_vae_encoder_fwd.__get__(vae.encoder, vae.encoder.__class__)
118
+ vae.decoder.forward = my_vae_decoder_fwd.__get__(vae.decoder, vae.decoder.__class__)
119
+ # add the skip connection convs
120
+ vae.decoder.skip_conv_1 = torch.nn.Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
121
+ vae.decoder.skip_conv_2 = torch.nn.Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
122
+ vae.decoder.skip_conv_3 = torch.nn.Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
123
+ vae.decoder.skip_conv_4 = torch.nn.Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
124
+ vae.decoder.ignore_skip = False
125
+ self.unet, self.vae = unet, vae
126
+ if pretrained_name == "day_to_night":
127
+ url = "https://www.cs.cmu.edu/~img2img-turbo/models/day2night.pkl"
128
+ self.load_ckpt_from_url(url, ckpt_folder)
129
+ self.timesteps = torch.tensor([999], device="cuda").long()
130
+ self.caption = "driving in the night"
131
+ self.direction = "a2b"
132
+ elif pretrained_name == "night_to_day":
133
+ url = "https://www.cs.cmu.edu/~img2img-turbo/models/night2day.pkl"
134
+ self.load_ckpt_from_url(url, ckpt_folder)
135
+ self.timesteps = torch.tensor([999], device="cuda").long()
136
+ self.caption = "driving in the day"
137
+ self.direction = "b2a"
138
+ elif pretrained_name == "clear_to_rainy":
139
+ url = "https://www.cs.cmu.edu/~img2img-turbo/models/clear2rainy.pkl"
140
+ self.load_ckpt_from_url(url, ckpt_folder)
141
+ self.timesteps = torch.tensor([999], device="cuda").long()
142
+ self.caption = "driving in heavy rain"
143
+ self.direction = "a2b"
144
+ elif pretrained_name == "rainy_to_clear":
145
+ url = "https://www.cs.cmu.edu/~img2img-turbo/models/rainy2clear.pkl"
146
+ self.load_ckpt_from_url(url, ckpt_folder)
147
+ self.timesteps = torch.tensor([999], device="cuda").long()
148
+ self.caption = "driving in the day"
149
+ self.direction = "b2a"
150
+
151
+ elif pretrained_path is not None:
152
+ sd = torch.load(pretrained_path)
153
+ self.load_ckpt_from_state_dict(sd)
154
+ self.timesteps = torch.tensor([999], device="cuda").long()
155
+ self.caption = None
156
+ self.direction = None
157
+
158
+ self.vae_enc.cuda()
159
+ self.vae_dec.cuda()
160
+ self.unet.cuda()
161
+
162
+ def load_ckpt_from_state_dict(self, sd):
163
+ lora_conf_encoder = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["l_target_modules_encoder"], lora_alpha=sd["rank_unet"])
164
+ lora_conf_decoder = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["l_target_modules_decoder"], lora_alpha=sd["rank_unet"])
165
+ lora_conf_others = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["l_modules_others"], lora_alpha=sd["rank_unet"])
166
+ self.unet.add_adapter(lora_conf_encoder, adapter_name="default_encoder")
167
+ self.unet.add_adapter(lora_conf_decoder, adapter_name="default_decoder")
168
+ self.unet.add_adapter(lora_conf_others, adapter_name="default_others")
169
+ for n, p in self.unet.named_parameters():
170
+ name_sd = n.replace(".default_encoder.weight", ".weight")
171
+ if "lora" in n and "default_encoder" in n:
172
+ p.data.copy_(sd["sd_encoder"][name_sd])
173
+ for n, p in self.unet.named_parameters():
174
+ name_sd = n.replace(".default_decoder.weight", ".weight")
175
+ if "lora" in n and "default_decoder" in n:
176
+ p.data.copy_(sd["sd_decoder"][name_sd])
177
+ for n, p in self.unet.named_parameters():
178
+ name_sd = n.replace(".default_others.weight", ".weight")
179
+ if "lora" in n and "default_others" in n:
180
+ p.data.copy_(sd["sd_other"][name_sd])
181
+ self.unet.set_adapter(["default_encoder", "default_decoder", "default_others"])
182
+
183
+ vae_lora_config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian", target_modules=sd["vae_lora_target_modules"])
184
+ self.vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
185
+ self.vae.decoder.gamma = 1
186
+ self.vae_b2a = copy.deepcopy(self.vae)
187
+ self.vae_enc = VAE_encode(self.vae, vae_b2a=self.vae_b2a)
188
+ self.vae_enc.load_state_dict(sd["sd_vae_enc"])
189
+ self.vae_dec = VAE_decode(self.vae, vae_b2a=self.vae_b2a)
190
+ self.vae_dec.load_state_dict(sd["sd_vae_dec"])
191
+
192
+ def load_ckpt_from_url(self, url, ckpt_folder):
193
+ os.makedirs(ckpt_folder, exist_ok=True)
194
+ outf = os.path.join(ckpt_folder, os.path.basename(url))
195
+ download_url(url, outf)
196
+ sd = torch.load(outf)
197
+ self.load_ckpt_from_state_dict(sd)
198
+
199
+ @staticmethod
200
+ def forward_with_networks(x, direction, vae_enc, unet, vae_dec, sched, timesteps, text_emb):
201
+ B = x.shape[0]
202
+ assert direction in ["a2b", "b2a"]
203
+ x_enc = vae_enc(x, direction=direction).to(x.dtype)
204
+ model_pred = unet(x_enc, timesteps, encoder_hidden_states=text_emb,).sample
205
+ x_out = torch.stack([sched.step(model_pred[i], timesteps[i], x_enc[i], return_dict=True).prev_sample for i in range(B)])
206
+ x_out_decoded = vae_dec(x_out, direction=direction)
207
+ return x_out_decoded
208
+
209
+ @staticmethod
210
+ def get_traininable_params(unet, vae_a2b, vae_b2a):
211
+ # add all unet parameters
212
+ params_gen = list(unet.conv_in.parameters())
213
+ unet.conv_in.requires_grad_(True)
214
+ unet.set_adapters(["default_encoder", "default_decoder", "default_others"])
215
+ for n,p in unet.named_parameters():
216
+ if "lora" in n and "default" in n:
217
+ assert p.requires_grad
218
+ params_gen.append(p)
219
+
220
+ # add all vae_a2b parameters
221
+ for n,p in vae_a2b.named_parameters():
222
+ if "lora" in n and "vae_skip" in n:
223
+ assert p.requires_grad
224
+ params_gen.append(p)
225
+ params_gen = params_gen + list(vae_a2b.decoder.skip_conv_1.parameters())
226
+ params_gen = params_gen + list(vae_a2b.decoder.skip_conv_2.parameters())
227
+ params_gen = params_gen + list(vae_a2b.decoder.skip_conv_3.parameters())
228
+ params_gen = params_gen + list(vae_a2b.decoder.skip_conv_4.parameters())
229
+
230
+ # add all vae_b2a parameters
231
+ for n,p in vae_b2a.named_parameters():
232
+ if "lora" in n and "vae_skip" in n:
233
+ assert p.requires_grad
234
+ params_gen.append(p)
235
+ params_gen = params_gen + list(vae_b2a.decoder.skip_conv_1.parameters())
236
+ params_gen = params_gen + list(vae_b2a.decoder.skip_conv_2.parameters())
237
+ params_gen = params_gen + list(vae_b2a.decoder.skip_conv_3.parameters())
238
+ params_gen = params_gen + list(vae_b2a.decoder.skip_conv_4.parameters())
239
+ return params_gen
240
+
241
+ def forward(self, x_t, direction=None, caption=None, caption_emb=None):
242
+ if direction is None:
243
+ assert self.direction is not None
244
+ direction = self.direction
245
+ if caption is None and caption_emb is None:
246
+ assert self.caption is not None
247
+ caption = self.caption
248
+ if caption_emb is not None:
249
+ caption_enc = caption_emb
250
+ else:
251
+ caption_tokens = self.tokenizer(caption, max_length=self.tokenizer.model_max_length,
252
+ padding="max_length", truncation=True, return_tensors="pt").input_ids.to(x_t.device)
253
+ caption_enc = self.text_encoder(caption_tokens)[0].detach().clone()
254
+ return self.forward_with_networks(x_t, direction, self.vae_enc, self.unet, self.vae_dec, self.sched, self.timesteps, caption_enc)
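
For reference, a minimal inference sketch for the class above, assuming src/ is on sys.path and a CUDA device is available; it mirrors how src/inference_unpaired.py drives the model, and the image paths are hypothetical:

import torch
from PIL import Image
from torchvision import transforms
from cyclegan_turbo import CycleGAN_Turbo

# load pretrained day-to-night weights (the checkpoint is downloaded on first use)
model = CycleGAN_Turbo(pretrained_name="day_to_night")
model.eval()

# scale a 512x512 RGB input to [-1, 1]
img = Image.open("example_day.png").convert("RGB").resize((512, 512), Image.LANCZOS)
x_t = transforms.Normalize([0.5], [0.5])(transforms.ToTensor()(img)).unsqueeze(0).cuda()

with torch.no_grad():
    out = model(x_t)  # uses the caption and direction stored for "day_to_night"
transforms.ToPILImage()(out[0].cpu() * 0.5 + 0.5).save("example_night.png")
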
src/image_prep.py ADDED
@@ -0,0 +1,12 @@
1
+ import numpy as np
2
+ from PIL import Image
3
+ import cv2
4
+
5
+
6
+ def canny_from_pil(image, low_threshold=100, high_threshold=200):
7
+ image = np.array(image)
8
+ image = cv2.Canny(image, low_threshold, high_threshold)
9
+ image = image[:, :, None]
10
+ image = np.concatenate([image, image, image], axis=2)
11
+ control_image = Image.fromarray(image)
12
+ return control_image
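
A quick usage sketch for the helper above, assuming src/ is on sys.path (the file names are hypothetical); src/inference_paired.py uses it the same way for the edge_to_image model:

from PIL import Image
from image_prep import canny_from_pil

img = Image.open("photo.png").convert("RGB")
edges = canny_from_pil(img, low_threshold=100, high_threshold=200)  # 3-channel PIL edge map
edges.save("photo_canny.png")
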
src/inference_paired.py ADDED
@@ -0,0 +1,65 @@
1
+ import os
2
+ import argparse
3
+ import numpy as np
4
+ from PIL import Image
5
+ import torch
6
+ from torchvision import transforms
7
+ import torchvision.transforms.functional as F
8
+ from pix2pix_turbo import Pix2Pix_Turbo
9
+ from image_prep import canny_from_pil
10
+
11
+ if __name__ == "__main__":
12
+ parser = argparse.ArgumentParser()
13
+ parser.add_argument('--input_image', type=str, required=True, help='path to the input image')
14
+ parser.add_argument('--prompt', type=str, required=True, help='the prompt to be used')
15
+ parser.add_argument('--model_name', type=str, default='', help='name of the pretrained model to be used')
16
+ parser.add_argument('--model_path', type=str, default='', help='path to a model state dict to be used')
17
+ parser.add_argument('--output_dir', type=str, default='output', help='the directory to save the output')
18
+ parser.add_argument('--low_threshold', type=int, default=100, help='Canny low threshold')
19
+ parser.add_argument('--high_threshold', type=int, default=200, help='Canny high threshold')
20
+ parser.add_argument('--gamma', type=float, default=0.4, help='The sketch interpolation guidance amount')
21
+ parser.add_argument('--seed', type=int, default=42, help='Random seed to be used')
22
+ args = parser.parse_args()
23
+
24
+ # only one of model_name and model_path should be provided
25
+ if (args.model_name == '') == (args.model_path == ''):
26
+ raise ValueError('Exactly one of model_name or model_path should be provided')
27
+
28
+ os.makedirs(args.output_dir, exist_ok=True)
29
+
30
+ # initialize the model
31
+ model = Pix2Pix_Turbo(pretrained_name=args.model_name, pretrained_path=args.model_path)
32
+ model.set_eval()
33
+
34
+ # make sure that the input image is a multiple of 8
35
+ input_image = Image.open(args.input_image).convert('RGB')
36
+ new_width = input_image.width - input_image.width % 8
37
+ new_height = input_image.height - input_image.height % 8
38
+ input_image = input_image.resize((new_width, new_height), Image.LANCZOS)
39
+ bname = os.path.basename(args.input_image)
40
+
41
+ # translate the image
42
+ with torch.no_grad():
43
+ if args.model_name == 'edge_to_image':
44
+ canny = canny_from_pil(input_image, args.low_threshold, args.high_threshold)
45
+ canny_viz_inv = Image.fromarray(255 - np.array(canny))
46
+ canny_viz_inv.save(os.path.join(args.output_dir, bname.replace('.png', '_canny.png')))
47
+ c_t = F.to_tensor(canny).unsqueeze(0).cuda()
48
+ output_image = model(c_t, args.prompt)
49
+
50
+ elif args.model_name == 'sketch_to_image_stochastic':
51
+ image_t = F.to_tensor(input_image) < 0.5
52
+ c_t = image_t.unsqueeze(0).cuda().float()
53
+ torch.manual_seed(args.seed)
54
+ B, C, H, W = c_t.shape
55
+ noise = torch.randn((1, 4, H // 8, W // 8), device=c_t.device)
56
+ output_image = model(c_t, args.prompt, deterministic=False, r=args.gamma, noise_map=noise)
57
+
58
+ else:
59
+ c_t = F.to_tensor(input_image).unsqueeze(0).cuda()
60
+ output_image = model(c_t, args.prompt)
61
+
62
+ output_pil = transforms.ToPILImage()(output_image[0].cpu() * 0.5 + 0.5)
63
+
64
+ # save the output image
65
+ output_pil.save(os.path.join(args.output_dir, bname))
src/inference_unpaired.py ADDED
@@ -0,0 +1,53 @@
1
+ import os
2
+ import argparse
3
+ from PIL import Image
4
+ import torch
5
+ from torchvision import transforms
6
+ from cyclegan_turbo import CycleGAN_Turbo
7
+ from my_utils.training_utils import build_transform
8
+
9
+
10
+ if __name__ == "__main__":
11
+ parser = argparse.ArgumentParser()
12
+ parser.add_argument('--input_image', type=str, required=True, help='path to the input image')
13
+ parser.add_argument('--prompt', type=str, required=False, help='the prompt to be used. It is required when loading a custom model_path.')
14
+ parser.add_argument('--model_name', type=str, default=None, help='name of the pretrained model to be used')
15
+ parser.add_argument('--model_path', type=str, default=None, help='path to a local model state dict to be used')
16
+ parser.add_argument('--output_dir', type=str, default='output', help='the directory to save the output')
17
+ parser.add_argument('--image_prep', type=str, default='resize_512x512', help='the image preparation method')
18
+ parser.add_argument('--direction', type=str, default=None, help='the direction of translation. None for pretrained models, a2b or b2a for custom paths.')
19
+ args = parser.parse_args()
20
+
21
+ # only one of model_name and model_path should be provided
22
+ if (args.model_name is None) == (args.model_path is None):
23
+ raise ValueError('Exactly one of model_name or model_path should be provided')
24
+
25
+ if args.model_path is not None and args.prompt is None:
26
+ raise ValueError('prompt is required when loading a custom model_path.')
27
+
28
+ if args.model_name is not None:
29
+ assert args.prompt is None, 'prompt is not required when loading a pretrained model.'
30
+ assert args.direction is None, 'direction is not required when loading a pretrained model.'
31
+
32
+ # initialize the model
33
+ model = CycleGAN_Turbo(pretrained_name=args.model_name, pretrained_path=args.model_path)
34
+ model.eval()
35
+ model.unet.enable_xformers_memory_efficient_attention()
36
+
37
+ T_val = build_transform(args.image_prep)
38
+
39
+ input_image = Image.open(args.input_image).convert('RGB')
40
+ # translate the image
41
+ with torch.no_grad():
42
+ input_img = T_val(input_image)
43
+ x_t = transforms.ToTensor()(input_img)
44
+ x_t = transforms.Normalize([0.5], [0.5])(x_t).unsqueeze(0).cuda()
45
+ output = model(x_t, direction=args.direction, caption=args.prompt)
46
+
47
+ output_pil = transforms.ToPILImage()(output[0].cpu() * 0.5 + 0.5)
48
+ output_pil = output_pil.resize((input_image.width, input_image.height), Image.LANCZOS)
49
+
50
+ # save the output image
51
+ bname = os.path.basename(args.input_image)
52
+ os.makedirs(args.output_dir, exist_ok=True)
53
+ output_pil.save(os.path.join(args.output_dir, bname))
src/model.py ADDED
@@ -0,0 +1,73 @@
1
+ import os
2
+ import requests
3
+ from tqdm import tqdm
4
+ from diffusers import DDPMScheduler
5
+
6
+
7
+ def make_1step_sched():
8
+ noise_scheduler_1step = DDPMScheduler.from_pretrained("stabilityai/sd-turbo", subfolder="scheduler")
9
+ noise_scheduler_1step.set_timesteps(1, device="cuda")
10
+ noise_scheduler_1step.alphas_cumprod = noise_scheduler_1step.alphas_cumprod.cuda()
11
+ return noise_scheduler_1step
12
+
13
+
14
+ def my_vae_encoder_fwd(self, sample):
15
+ sample = self.conv_in(sample)
16
+ l_blocks = []
17
+ # down
18
+ for down_block in self.down_blocks:
19
+ l_blocks.append(sample)
20
+ sample = down_block(sample)
21
+ # middle
22
+ sample = self.mid_block(sample)
23
+ sample = self.conv_norm_out(sample)
24
+ sample = self.conv_act(sample)
25
+ sample = self.conv_out(sample)
26
+ self.current_down_blocks = l_blocks
27
+ return sample
28
+
29
+
30
+ def my_vae_decoder_fwd(self, sample, latent_embeds=None):
31
+ sample = self.conv_in(sample)
32
+ upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
33
+ # middle
34
+ sample = self.mid_block(sample, latent_embeds)
35
+ sample = sample.to(upscale_dtype)
36
+ if not self.ignore_skip:
37
+ skip_convs = [self.skip_conv_1, self.skip_conv_2, self.skip_conv_3, self.skip_conv_4]
38
+ # up
39
+ for idx, up_block in enumerate(self.up_blocks):
40
+ skip_in = skip_convs[idx](self.incoming_skip_acts[::-1][idx] * self.gamma)
41
+ # add skip
42
+ sample = sample + skip_in
43
+ sample = up_block(sample, latent_embeds)
44
+ else:
45
+ for idx, up_block in enumerate(self.up_blocks):
46
+ sample = up_block(sample, latent_embeds)
47
+ # post-process
48
+ if latent_embeds is None:
49
+ sample = self.conv_norm_out(sample)
50
+ else:
51
+ sample = self.conv_norm_out(sample, latent_embeds)
52
+ sample = self.conv_act(sample)
53
+ sample = self.conv_out(sample)
54
+ return sample
55
+
56
+
57
+ def download_url(url, outf):
58
+ if not os.path.exists(outf):
59
+ print(f"Downloading checkpoint to {outf}")
60
+ response = requests.get(url, stream=True)
61
+ total_size_in_bytes = int(response.headers.get('content-length', 0))
62
+ block_size = 1024 # 1 Kibibyte
63
+ progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
64
+ with open(outf, 'wb') as file:
65
+ for data in response.iter_content(block_size):
66
+ progress_bar.update(len(data))
67
+ file.write(data)
68
+ progress_bar.close()
69
+ if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
70
+ print("ERROR, something went wrong")
71
+ print(f"Downloaded successfully to {outf}")
72
+ else:
73
+ print(f"Skipping download, {outf} already exists")
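
To show how these pieces fit together, here is a minimal sketch (an assumption-laden illustration, not part of the upload) of binding the patched forwards to an SD-Turbo VAE, mirroring how cyclegan_turbo.py and pix2pix_turbo.py bind them, and round-tripping a tensor through it:

import torch
from diffusers import AutoencoderKL
from model import my_vae_encoder_fwd, my_vae_decoder_fwd  # assumes src/ is on sys.path

vae = AutoencoderKL.from_pretrained("stabilityai/sd-turbo", subfolder="vae").cuda()
# bind the custom forwards so the encoder records its per-block activations
vae.encoder.forward = my_vae_encoder_fwd.__get__(vae.encoder, vae.encoder.__class__)
vae.decoder.forward = my_vae_decoder_fwd.__get__(vae.decoder, vae.decoder.__class__)
vae.decoder.ignore_skip = True   # no skip convs attached in this bare sketch
vae.decoder.gamma = 1

x = torch.randn(1, 3, 512, 512).cuda()  # stand-in for an image batch in [-1, 1]
with torch.no_grad():
    latents = vae.encode(x).latent_dist.sample() * vae.config.scaling_factor
    # hand the recorded encoder activations to the decoder (consumed when ignore_skip is False)
    vae.decoder.incoming_skip_acts = vae.encoder.current_down_blocks
    recon = vae.decode(latents / vae.config.scaling_factor).sample.clamp(-1, 1)
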
src/my_utils/dino_struct.py ADDED
@@ -0,0 +1,185 @@
1
+ import torch
2
+ import torchvision
3
+ import torch.nn.functional as F
4
+
5
+
6
+ def attn_cosine_sim(x, eps=1e-08):
7
+ x = x[0] # TEMP: getting rid of redundant dimension, TBF
8
+ norm1 = x.norm(dim=2, keepdim=True)
9
+ factor = torch.clamp(norm1 @ norm1.permute(0, 2, 1), min=eps)
10
+ sim_matrix = (x @ x.permute(0, 2, 1)) / factor
11
+ return sim_matrix
12
+
13
+
14
+ class VitExtractor:
15
+ BLOCK_KEY = 'block'
16
+ ATTN_KEY = 'attn'
17
+ PATCH_IMD_KEY = 'patch_imd'
18
+ QKV_KEY = 'qkv'
19
+ KEY_LIST = [BLOCK_KEY, ATTN_KEY, PATCH_IMD_KEY, QKV_KEY]
20
+
21
+ def __init__(self, model_name, device):
22
+ # pdb.set_trace()
23
+ self.model = torch.hub.load('facebookresearch/dino:main', model_name).to(device)
24
+ self.model.eval()
25
+ self.model_name = model_name
26
+ self.hook_handlers = []
27
+ self.layers_dict = {}
28
+ self.outputs_dict = {}
29
+ for key in VitExtractor.KEY_LIST:
30
+ self.layers_dict[key] = []
31
+ self.outputs_dict[key] = []
32
+ self._init_hooks_data()
33
+
34
+ def _init_hooks_data(self):
35
+ self.layers_dict[VitExtractor.BLOCK_KEY] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
36
+ self.layers_dict[VitExtractor.ATTN_KEY] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
37
+ self.layers_dict[VitExtractor.QKV_KEY] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
38
+ self.layers_dict[VitExtractor.PATCH_IMD_KEY] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
39
+ for key in VitExtractor.KEY_LIST:
40
+ # self.layers_dict[key] = kwargs[key] if key in kwargs.keys() else []
41
+ self.outputs_dict[key] = []
42
+
43
+ def _register_hooks(self, **kwargs):
44
+ for block_idx, block in enumerate(self.model.blocks):
45
+ if block_idx in self.layers_dict[VitExtractor.BLOCK_KEY]:
46
+ self.hook_handlers.append(block.register_forward_hook(self._get_block_hook()))
47
+ if block_idx in self.layers_dict[VitExtractor.ATTN_KEY]:
48
+ self.hook_handlers.append(block.attn.attn_drop.register_forward_hook(self._get_attn_hook()))
49
+ if block_idx in self.layers_dict[VitExtractor.QKV_KEY]:
50
+ self.hook_handlers.append(block.attn.qkv.register_forward_hook(self._get_qkv_hook()))
51
+ if block_idx in self.layers_dict[VitExtractor.PATCH_IMD_KEY]:
52
+ self.hook_handlers.append(block.attn.register_forward_hook(self._get_patch_imd_hook()))
53
+
54
+ def _clear_hooks(self):
55
+ for handler in self.hook_handlers:
56
+ handler.remove()
57
+ self.hook_handlers = []
58
+
59
+ def _get_block_hook(self):
60
+ def _get_block_output(model, input, output):
61
+ self.outputs_dict[VitExtractor.BLOCK_KEY].append(output)
62
+
63
+ return _get_block_output
64
+
65
+ def _get_attn_hook(self):
66
+ def _get_attn_output(model, inp, output):
67
+ self.outputs_dict[VitExtractor.ATTN_KEY].append(output)
68
+
69
+ return _get_attn_output
70
+
71
+ def _get_qkv_hook(self):
72
+ def _get_qkv_output(model, inp, output):
73
+ self.outputs_dict[VitExtractor.QKV_KEY].append(output)
74
+
75
+ return _get_qkv_output
76
+
77
+ # TODO: CHECK ATTN OUTPUT TUPLE
78
+ def _get_patch_imd_hook(self):
79
+ def _get_attn_output(model, inp, output):
80
+ self.outputs_dict[VitExtractor.PATCH_IMD_KEY].append(output[0])
81
+
82
+ return _get_attn_output
83
+
84
+ def get_feature_from_input(self, input_img): # List([B, N, D])
85
+ self._register_hooks()
86
+ self.model(input_img)
87
+ feature = self.outputs_dict[VitExtractor.BLOCK_KEY]
88
+ self._clear_hooks()
89
+ self._init_hooks_data()
90
+ return feature
91
+
92
+ def get_qkv_feature_from_input(self, input_img):
93
+ self._register_hooks()
94
+ self.model(input_img)
95
+ feature = self.outputs_dict[VitExtractor.QKV_KEY]
96
+ self._clear_hooks()
97
+ self._init_hooks_data()
98
+ return feature
99
+
100
+ def get_attn_feature_from_input(self, input_img):
101
+ self._register_hooks()
102
+ self.model(input_img)
103
+ feature = self.outputs_dict[VitExtractor.ATTN_KEY]
104
+ self._clear_hooks()
105
+ self._init_hooks_data()
106
+ return feature
107
+
108
+ def get_patch_size(self):
109
+ return 8 if "8" in self.model_name else 16
110
+
111
+ def get_width_patch_num(self, input_img_shape):
112
+ b, c, h, w = input_img_shape
113
+ patch_size = self.get_patch_size()
114
+ return w // patch_size
115
+
116
+ def get_height_patch_num(self, input_img_shape):
117
+ b, c, h, w = input_img_shape
118
+ patch_size = self.get_patch_size()
119
+ return h // patch_size
120
+
121
+ def get_patch_num(self, input_img_shape):
122
+ patch_num = 1 + (self.get_height_patch_num(input_img_shape) * self.get_width_patch_num(input_img_shape))
123
+ return patch_num
124
+
125
+ def get_head_num(self):
126
+ if "dino" in self.model_name:
127
+ return 6 if "s" in self.model_name else 12
128
+ return 6 if "small" in self.model_name else 12
129
+
130
+ def get_embedding_dim(self):
131
+ if "dino" in self.model_name:
132
+ return 384 if "s" in self.model_name else 768
133
+ return 384 if "small" in self.model_name else 768
134
+
135
+ def get_queries_from_qkv(self, qkv, input_img_shape):
136
+ patch_num = self.get_patch_num(input_img_shape)
137
+ head_num = self.get_head_num()
138
+ embedding_dim = self.get_embedding_dim()
139
+ q = qkv.reshape(patch_num, 3, head_num, embedding_dim // head_num).permute(1, 2, 0, 3)[0]
140
+ return q
141
+
142
+ def get_keys_from_qkv(self, qkv, input_img_shape):
143
+ patch_num = self.get_patch_num(input_img_shape)
144
+ head_num = self.get_head_num()
145
+ embedding_dim = self.get_embedding_dim()
146
+ k = qkv.reshape(patch_num, 3, head_num, embedding_dim // head_num).permute(1, 2, 0, 3)[1]
147
+ return k
148
+
149
+ def get_values_from_qkv(self, qkv, input_img_shape):
150
+ patch_num = self.get_patch_num(input_img_shape)
151
+ head_num = self.get_head_num()
152
+ embedding_dim = self.get_embedding_dim()
153
+ v = qkv.reshape(patch_num, 3, head_num, embedding_dim // head_num).permute(1, 2, 0, 3)[2]
154
+ return v
155
+
156
+ def get_keys_from_input(self, input_img, layer_num):
157
+ qkv_features = self.get_qkv_feature_from_input(input_img)[layer_num]
158
+ keys = self.get_keys_from_qkv(qkv_features, input_img.shape)
159
+ return keys
160
+
161
+ def get_keys_self_sim_from_input(self, input_img, layer_num):
162
+ keys = self.get_keys_from_input(input_img, layer_num=layer_num)
163
+ h, t, d = keys.shape
164
+ concatenated_keys = keys.transpose(0, 1).reshape(t, h * d)
165
+ ssim_map = attn_cosine_sim(concatenated_keys[None, None, ...])
166
+ return ssim_map
167
+
168
+
169
+ class DinoStructureLoss:
170
+ def __init__(self, ):
171
+ self.extractor = VitExtractor(model_name="dino_vitb8", device="cuda")
172
+ self.preprocess = torchvision.transforms.Compose([
173
+ torchvision.transforms.Resize(224),
174
+ torchvision.transforms.ToTensor(),
175
+ torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
176
+ ])
177
+
178
+ def calculate_global_ssim_loss(self, outputs, inputs):
179
+ loss = 0.0
180
+ for a, b in zip(inputs, outputs): # avoid memory limitations
181
+ with torch.no_grad():
182
+ target_keys_self_sim = self.extractor.get_keys_self_sim_from_input(a.unsqueeze(0), layer_num=11)
183
+ keys_ssim = self.extractor.get_keys_self_sim_from_input(b.unsqueeze(0), layer_num=11)
184
+ loss += F.mse_loss(keys_ssim, target_keys_self_sim)
185
+ return loss
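
A usage sketch for the structure metric above, assuming src/ is on sys.path and a CUDA device; the image paths are hypothetical, and dino_vitb8 is fetched via torch.hub on first use:

import torch
from PIL import Image
from my_utils.dino_struct import DinoStructureLoss

dino = DinoStructureLoss()
a = dino.preprocess(Image.open("input.png").convert("RGB").resize((224, 224))).unsqueeze(0).cuda()
b = dino.preprocess(Image.open("translated.png").convert("RGB").resize((224, 224))).unsqueeze(0).cuda()
with torch.no_grad():
    loss = dino.calculate_global_ssim_loss(b, a)  # lower means structure is better preserved
print(float(loss))
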
src/my_utils/training_utils.py ADDED
@@ -0,0 +1,409 @@
1
+ import os
2
+ import random
3
+ import argparse
4
+ import json
5
+ import torch
6
+ from PIL import Image
7
+ from torchvision import transforms
8
+ import torchvision.transforms.functional as F
9
+ from glob import glob
10
+
11
+
12
+ def parse_args_paired_training(input_args=None):
13
+ """
14
+ Parses command-line arguments used for configuring a paired training session (pix2pix-Turbo).
15
+ This function sets up an argument parser to handle various training options.
16
+
17
+ Returns:
18
+ argparse.Namespace: The parsed command-line arguments.
19
+ """
20
+ parser = argparse.ArgumentParser()
21
+ # args for the loss function
22
+ parser.add_argument("--gan_disc_type", default="vagan_clip")
23
+ parser.add_argument("--gan_loss_type", default="multilevel_sigmoid_s")
24
+ parser.add_argument("--lambda_gan", default=0.5, type=float)
25
+ parser.add_argument("--lambda_lpips", default=5, type=float)
26
+ parser.add_argument("--lambda_l2", default=1.0, type=float)
27
+ parser.add_argument("--lambda_clipsim", default=5.0, type=float)
28
+
29
+ # dataset options
30
+ parser.add_argument("--dataset_folder", required=True, type=str)
31
+ parser.add_argument("--train_image_prep", default="resized_crop_512", type=str)
32
+ parser.add_argument("--test_image_prep", default="resized_crop_512", type=str)
33
+
34
+ # validation eval args
35
+ parser.add_argument("--eval_freq", default=100, type=int)
36
+ parser.add_argument("--track_val_fid", default=False, action="store_true")
37
+ parser.add_argument("--num_samples_eval", type=int, default=100, help="Number of samples to use for all evaluation")
38
+
39
+ parser.add_argument("--viz_freq", type=int, default=100, help="Frequency of visualizing the outputs.")
40
+ parser.add_argument("--tracker_project_name", type=str, default="train_pix2pix_turbo", help="The name of the wandb project to log to.")
41
+
42
+ # details about the model architecture
43
+ parser.add_argument("--pretrained_model_name_or_path")
44
+ parser.add_argument("--revision", type=str, default=None,)
45
+ parser.add_argument("--variant", type=str, default=None,)
46
+ parser.add_argument("--tokenizer_name", type=str, default=None)
47
+ parser.add_argument("--lora_rank_unet", default=8, type=int)
48
+ parser.add_argument("--lora_rank_vae", default=4, type=int)
49
+
50
+ # training details
51
+ parser.add_argument("--output_dir", required=True)
52
+ parser.add_argument("--cache_dir", default=None,)
53
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
54
+ parser.add_argument("--resolution", type=int, default=512,)
55
+ parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader.")
56
+ parser.add_argument("--num_training_epochs", type=int, default=10)
57
+ parser.add_argument("--max_train_steps", type=int, default=10_000,)
58
+ parser.add_argument("--checkpointing_steps", type=int, default=500,)
59
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.",)
60
+ parser.add_argument("--gradient_checkpointing", action="store_true",)
61
+ parser.add_argument("--learning_rate", type=float, default=5e-6)
62
+ parser.add_argument("--lr_scheduler", type=str, default="constant",
63
+ help=(
64
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
65
+ ' "constant", "constant_with_warmup"]'
66
+ ),
67
+ )
68
+ parser.add_argument("--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler.")
69
+ parser.add_argument("--lr_num_cycles", type=int, default=1,
70
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
71
+ )
72
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
73
+
74
+ parser.add_argument("--dataloader_num_workers", type=int, default=0,)
75
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
76
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
77
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
78
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
79
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
80
+ parser.add_argument("--allow_tf32", action="store_true",
81
+ help=(
82
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
83
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
84
+ ),
85
+ )
86
+ parser.add_argument("--report_to", type=str, default="wandb",
87
+ help=(
88
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
89
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
90
+ ),
91
+ )
92
+ parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"],)
93
+ parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.")
94
+ parser.add_argument("--set_grads_to_none", action="store_true",)
95
+
96
+ if input_args is not None:
97
+ args = parser.parse_args(input_args)
98
+ else:
99
+ args = parser.parse_args()
100
+
101
+ return args
102
+
103
+
104
+ def parse_args_unpaired_training():
105
+ """
106
+ Parses command-line arguments used for configuring an unpaired session (CycleGAN-Turbo).
107
+ This function sets up an argument parser to handle various training options.
108
+
109
+ Returns:
110
+ argparse.Namespace: The parsed command-line arguments.
111
+ """
112
+
113
+ parser = argparse.ArgumentParser(description="Simple example of a CycleGAN-Turbo unpaired training script.")
114
+
115
+ # fixed random seed
116
+ parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
117
+
118
+ # args for the loss function
119
+ parser.add_argument("--gan_disc_type", default="vagan_clip")
120
+ parser.add_argument("--gan_loss_type", default="multilevel_sigmoid")
121
+ parser.add_argument("--lambda_gan", default=0.5, type=float)
122
+ parser.add_argument("--lambda_idt", default=1, type=float)
123
+ parser.add_argument("--lambda_cycle", default=1, type=float)
124
+ parser.add_argument("--lambda_cycle_lpips", default=10.0, type=float)
125
+ parser.add_argument("--lambda_idt_lpips", default=1.0, type=float)
126
+
127
+ # args for dataset and dataloader options
128
+ parser.add_argument("--dataset_folder", required=True, type=str)
129
+ parser.add_argument("--train_img_prep", required=True)
130
+ parser.add_argument("--val_img_prep", required=True)
131
+ parser.add_argument("--dataloader_num_workers", type=int, default=0)
132
+ parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader.")
133
+ parser.add_argument("--max_train_epochs", type=int, default=100)
134
+ parser.add_argument("--max_train_steps", type=int, default=None)
135
+
136
+ # args for the model
137
+ parser.add_argument("--pretrained_model_name_or_path", default="stabilityai/sd-turbo")
138
+ parser.add_argument("--revision", default=None, type=str)
139
+ parser.add_argument("--variant", default=None, type=str)
140
+ parser.add_argument("--lora_rank_unet", default=128, type=int)
141
+ parser.add_argument("--lora_rank_vae", default=4, type=int)
142
+
143
+ # args for validation and logging
144
+ parser.add_argument("--viz_freq", type=int, default=20)
145
+ parser.add_argument("--output_dir", type=str, required=True)
146
+ parser.add_argument("--report_to", type=str, default="wandb")
147
+ parser.add_argument("--tracker_project_name", type=str, required=True)
148
+ parser.add_argument("--validation_steps", type=int, default=500,)
149
+ parser.add_argument("--validation_num_images", type=int, default=-1, help="Number of images to use for validation. -1 to use all images.")
150
+ parser.add_argument("--checkpointing_steps", type=int, default=500)
151
+
152
+ # args for the optimization options
153
+ parser.add_argument("--learning_rate", type=float, default=5e-6,)
154
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
155
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
156
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
157
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
158
+ parser.add_argument("--max_grad_norm", default=10.0, type=float, help="Max gradient norm.")
159
+ parser.add_argument("--lr_scheduler", type=str, default="constant", help=(
160
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
161
+ ' "constant", "constant_with_warmup"]'
162
+ ),
163
+ )
164
+ parser.add_argument("--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler.")
165
+ parser.add_argument("--lr_num_cycles", type=int, default=1, help="Number of hard resets of the lr in cosine_with_restarts scheduler.",)
166
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
167
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
168
+
169
+ # memory saving options
170
+ parser.add_argument("--allow_tf32", action="store_true",
171
+ help=(
172
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
173
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
174
+ ),
175
+ )
176
+ parser.add_argument("--gradient_checkpointing", action="store_true",
177
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.")
178
+ parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.")
179
+
180
+ args = parser.parse_args()
181
+ return args
182
+
183
+
184
+ def build_transform(image_prep):
185
+ """
186
+ Constructs a transformation pipeline based on the specified image preparation method.
187
+
188
+ Parameters:
189
+ - image_prep (str): A string describing the desired image preparation
190
+
191
+ Returns:
192
+ - torchvision.transforms.Compose: A composable sequence of transformations to be applied to images.
193
+ """
194
+ if image_prep == "resized_crop_512":
195
+ T = transforms.Compose([
196
+ transforms.Resize(512, interpolation=transforms.InterpolationMode.LANCZOS),
197
+ transforms.CenterCrop(512),
198
+ ])
199
+ elif image_prep == "resize_286_randomcrop_256x256_hflip":
200
+ T = transforms.Compose([
201
+ transforms.Resize((286, 286), interpolation=Image.LANCZOS),
202
+ transforms.RandomCrop((256, 256)),
203
+ transforms.RandomHorizontalFlip(),
204
+ ])
205
+ elif image_prep in ["resize_256", "resize_256x256"]:
206
+ T = transforms.Compose([
207
+ transforms.Resize((256, 256), interpolation=Image.LANCZOS)
208
+ ])
209
+ elif image_prep in ["resize_512", "resize_512x512"]:
210
+ T = transforms.Compose([
211
+ transforms.Resize((512, 512), interpolation=Image.LANCZOS)
212
+ ])
213
+ elif image_prep == "no_resize":
214
+ T = transforms.Lambda(lambda x: x)
215
+ return T
216
+
217
+
218
+ class PairedDataset(torch.utils.data.Dataset):
219
+ def __init__(self, dataset_folder, split, image_prep, tokenizer):
220
+ """
221
+ Initialize the paired dataset object for loading and transforming paired data samples
222
+ from specified dataset folders.
223
+
224
+ This constructor sets up the paths to input and output folders based on the specified 'split',
225
+ loads the captions (or prompts) for the input images, and prepares the transformations and
226
+ tokenizer to be applied on the data.
227
+
228
+ Parameters:
229
+ - dataset_folder (str): The root folder containing the dataset, expected to include
230
+ sub-folders for different splits (e.g., 'train_A', 'train_B').
231
+ - split (str): The dataset split to use ('train' or 'test'), used to select the appropriate
232
+ sub-folders and caption files within the dataset folder.
233
+ - image_prep (str): The image preprocessing transformation to apply to each image.
234
+ - tokenizer: The tokenizer used for tokenizing the captions (or prompts).
235
+ """
236
+ super().__init__()
237
+ if split == "train":
238
+ self.input_folder = os.path.join(dataset_folder, "train_A")
239
+ self.output_folder = os.path.join(dataset_folder, "train_B")
240
+ captions = os.path.join(dataset_folder, "train_prompts.json")
241
+ elif split == "test":
242
+ self.input_folder = os.path.join(dataset_folder, "test_A")
243
+ self.output_folder = os.path.join(dataset_folder, "test_B")
244
+ captions = os.path.join(dataset_folder, "test_prompts.json")
245
+ with open(captions, "r") as f:
246
+ self.captions = json.load(f)
247
+ self.img_names = list(self.captions.keys())
248
+ self.T = build_transform(image_prep)
249
+ self.tokenizer = tokenizer
250
+
251
+ def __len__(self):
252
+ """
253
+ Returns:
254
+ int: The total number of items in the dataset.
255
+ """
256
+ return len(self.captions)
257
+
258
+ def __getitem__(self, idx):
259
+ """
260
+ Retrieves a dataset item given its index. Each item consists of an input image,
261
+ its corresponding output image, the captions associated with the input image,
262
+ and the tokenized form of this caption.
263
+
264
+ This method performs the necessary preprocessing on both the input and output images,
265
+ including scaling and normalization, as well as tokenizing the caption using a provided tokenizer.
266
+
267
+ Parameters:
268
+ - idx (int): The index of the item to retrieve.
269
+
270
+ Returns:
271
+ dict: A dictionary containing the following key-value pairs:
272
+ - "output_pixel_values": a tensor of the preprocessed output image with pixel values
273
+ scaled to [-1, 1].
274
+ - "conditioning_pixel_values": a tensor of the preprocessed input image with pixel values
275
+ scaled to [0, 1].
276
+ - "caption": the text caption.
277
+ - "input_ids": a tensor of the tokenized caption.
278
+
279
+ Note:
280
+ The actual preprocessing steps (scaling and normalization) for images are defined externally
281
+ and passed to this class through the `image_prep` parameter during initialization. The
282
+ tokenization process relies on the `tokenizer` also provided at initialization, which
283
+ should be compatible with the models intended to be used with this dataset.
284
+ """
285
+ img_name = self.img_names[idx]
286
+ input_img = Image.open(os.path.join(self.input_folder, img_name))
287
+ output_img = Image.open(os.path.join(self.output_folder, img_name))
288
+ caption = self.captions[img_name]
289
+
290
+ # input images scaled to 0,1
291
+ img_t = self.T(input_img)
292
+ img_t = F.to_tensor(img_t)
293
+ # output images scaled to -1,1
294
+ output_t = self.T(output_img)
295
+ output_t = F.to_tensor(output_t)
296
+ output_t = F.normalize(output_t, mean=[0.5], std=[0.5])
297
+
298
+ input_ids = self.tokenizer(
299
+ caption, max_length=self.tokenizer.model_max_length,
300
+ padding="max_length", truncation=True, return_tensors="pt"
301
+ ).input_ids
302
+
303
+ return {
304
+ "output_pixel_values": output_t,
305
+ "conditioning_pixel_values": img_t,
306
+ "caption": caption,
307
+ "input_ids": input_ids,
308
+ }
309
+
310
+
311
+ class UnpairedDataset(torch.utils.data.Dataset):
312
+ def __init__(self, dataset_folder, split, image_prep, tokenizer):
313
+ """
314
+ A dataset class for loading unpaired data samples from two distinct domains (source and target),
315
+ typically used in unsupervised learning tasks like image-to-image translation.
316
+
317
+ The class supports loading images from specified dataset folders, applying predefined image
318
+ preprocessing transformations, and utilizing fixed textual prompts (captions) for each domain,
319
+ tokenized using a provided tokenizer.
320
+
321
+ Parameters:
322
+ - dataset_folder (str): Base directory of the dataset containing subdirectories (train_A, train_B, test_A, test_B)
323
+ - split (str): Indicates the dataset split to use. Expected values are 'train' or 'test'.
324
+ - image_prep (str): The image preprocessing transformation to apply to each image.
325
+ - tokenizer: The tokenizer used for tokenizing the captions (or prompts).
326
+ """
327
+ super().__init__()
328
+ if split == "train":
329
+ self.source_folder = os.path.join(dataset_folder, "train_A")
330
+ self.target_folder = os.path.join(dataset_folder, "train_B")
331
+ elif split == "test":
332
+ self.source_folder = os.path.join(dataset_folder, "test_A")
333
+ self.target_folder = os.path.join(dataset_folder, "test_B")
334
+ self.tokenizer = tokenizer
335
+ with open(os.path.join(dataset_folder, "fixed_prompt_a.txt"), "r") as f:
336
+ self.fixed_caption_src = f.read().strip()
337
+ self.input_ids_src = self.tokenizer(
338
+ self.fixed_caption_src, max_length=self.tokenizer.model_max_length,
339
+ padding="max_length", truncation=True, return_tensors="pt"
340
+ ).input_ids
341
+
342
+ with open(os.path.join(dataset_folder, "fixed_prompt_b.txt"), "r") as f:
343
+ self.fixed_caption_tgt = f.read().strip()
344
+ self.input_ids_tgt = self.tokenizer(
345
+ self.fixed_caption_tgt, max_length=self.tokenizer.model_max_length,
346
+ padding="max_length", truncation=True, return_tensors="pt"
347
+ ).input_ids
348
+ # find all images in the source and target folders with all IMG extensions
349
+ self.l_imgs_src = []
350
+ for ext in ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.gif"]:
351
+ self.l_imgs_src.extend(glob(os.path.join(self.source_folder, ext)))
352
+ self.l_imgs_tgt = []
353
+ for ext in ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.gif"]:
354
+ self.l_imgs_tgt.extend(glob(os.path.join(self.target_folder, ext)))
355
+ self.T = build_transform(image_prep)
356
+
357
+ def __len__(self):
358
+ """
359
+ Returns:
360
+ int: The total number of items in the dataset.
361
+ """
362
+ return len(self.l_imgs_src) + len(self.l_imgs_tgt)
363
+
364
+ def __getitem__(self, index):
365
+ """
366
+ Fetches a pair of unaligned images from the source and target domains along with their
367
+ corresponding tokenized captions.
368
+
369
+ For the source domain, if the requested index is within the range of available images,
370
+ the specific image at that index is chosen. If the index exceeds the number of source
371
+ images, a random source image is selected. For the target domain,
372
+ an image is always randomly selected, irrespective of the index, to maintain the
373
+ unpaired nature of the dataset.
374
+
375
+ Both images are preprocessed according to the specified image transformation `T`, and normalized.
376
+ The fixed captions for both domains
377
+ are included along with their tokenized forms.
378
+
379
+ Parameters:
380
+ - index (int): The index of the source image to retrieve.
381
+
382
+ Returns:
383
+ dict: A dictionary containing processed data for a single training example, with the following keys:
384
+ - "pixel_values_src": The processed source image
385
+ - "pixel_values_tgt": The processed target image
386
+ - "caption_src": The fixed caption of the source domain.
387
+ - "caption_tgt": The fixed caption of the target domain.
388
+ - "input_ids_src": The source domain's fixed caption tokenized.
389
+ - "input_ids_tgt": The target domain's fixed caption tokenized.
390
+ """
391
+ if index < len(self.l_imgs_src):
392
+ img_path_src = self.l_imgs_src[index]
393
+ else:
394
+ img_path_src = random.choice(self.l_imgs_src)
395
+ img_path_tgt = random.choice(self.l_imgs_tgt)
396
+ img_pil_src = Image.open(img_path_src).convert("RGB")
397
+ img_pil_tgt = Image.open(img_path_tgt).convert("RGB")
398
+ img_t_src = F.to_tensor(self.T(img_pil_src))
399
+ img_t_tgt = F.to_tensor(self.T(img_pil_tgt))
400
+ img_t_src = F.normalize(img_t_src, mean=[0.5], std=[0.5])
401
+ img_t_tgt = F.normalize(img_t_tgt, mean=[0.5], std=[0.5])
402
+ return {
403
+ "pixel_values_src": img_t_src,
404
+ "pixel_values_tgt": img_t_tgt,
405
+ "caption_src": self.fixed_caption_src,
406
+ "caption_tgt": self.fixed_caption_tgt,
407
+ "input_ids_src": self.input_ids_src,
408
+ "input_ids_tgt": self.input_ids_tgt,
409
+ }
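
As a sketch of how the dataset class above might be wired up (the dataset folder path is hypothetical; it must contain train_A/, train_B/ and the two fixed_prompt_*.txt files described in the docstring):

import torch
from transformers import AutoTokenizer
from my_utils.training_utils import UnpairedDataset  # assumes src/ is on sys.path

tokenizer = AutoTokenizer.from_pretrained("stabilityai/sd-turbo", subfolder="tokenizer")
dataset = UnpairedDataset(
    dataset_folder="data/my_day2night",   # hypothetical path
    split="train",
    image_prep="resize_286_randomcrop_256x256_hflip",
    tokenizer=tokenizer,
)
loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)
batch = next(iter(loader))
print(batch["pixel_values_src"].shape, batch["caption_src"][0])
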
src/pix2pix_turbo.py ADDED
@@ -0,0 +1,227 @@
1
+ import os
2
+ import requests
3
+ import sys
4
+ import copy
5
+ from tqdm import tqdm
6
+ import torch
7
+ from transformers import AutoTokenizer, CLIPTextModel
8
+ from diffusers import AutoencoderKL, UNet2DConditionModel
9
+ from diffusers.utils.peft_utils import set_weights_and_activate_adapters
10
+ from peft import LoraConfig
11
+ p = "src/"
12
+ sys.path.append(p)
13
+ from model import make_1step_sched, my_vae_encoder_fwd, my_vae_decoder_fwd
14
+
15
+
16
+ class TwinConv(torch.nn.Module):
17
+ def __init__(self, convin_pretrained, convin_curr):
18
+ super(TwinConv, self).__init__()
19
+ self.conv_in_pretrained = copy.deepcopy(convin_pretrained)
20
+ self.conv_in_curr = copy.deepcopy(convin_curr)
21
+ self.r = None
22
+
23
+ def forward(self, x):
24
+ x1 = self.conv_in_pretrained(x).detach()
25
+ x2 = self.conv_in_curr(x)
26
+ return x1 * (1 - self.r) + x2 * (self.r)
27
+
+
+ class Pix2Pix_Turbo(torch.nn.Module):
+ def __init__(self, pretrained_name=None, pretrained_path=None, ckpt_folder="checkpoints", lora_rank_unet=8, lora_rank_vae=4):
+ super().__init__()
+ self.tokenizer = AutoTokenizer.from_pretrained("stabilityai/sd-turbo", subfolder="tokenizer")
+ self.text_encoder = CLIPTextModel.from_pretrained("stabilityai/sd-turbo", subfolder="text_encoder").cuda()
+ self.sched = make_1step_sched()
+
+ vae = AutoencoderKL.from_pretrained("stabilityai/sd-turbo", subfolder="vae")
+ vae.encoder.forward = my_vae_encoder_fwd.__get__(vae.encoder, vae.encoder.__class__)
+ vae.decoder.forward = my_vae_decoder_fwd.__get__(vae.decoder, vae.decoder.__class__)
+ # add the skip connection convs
+ vae.decoder.skip_conv_1 = torch.nn.Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
+ vae.decoder.skip_conv_2 = torch.nn.Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
+ vae.decoder.skip_conv_3 = torch.nn.Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
+ vae.decoder.skip_conv_4 = torch.nn.Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
+ vae.decoder.ignore_skip = False
+ unet = UNet2DConditionModel.from_pretrained("stabilityai/sd-turbo", subfolder="unet")
+
+ if pretrained_name == "edge_to_image":
+ url = "https://www.cs.cmu.edu/~img2img-turbo/models/edge_to_image_loras.pkl"
+ os.makedirs(ckpt_folder, exist_ok=True)
+ outf = os.path.join(ckpt_folder, "edge_to_image_loras.pkl")
+ if not os.path.exists(outf):
+ print(f"Downloading checkpoint to {outf}")
+ response = requests.get(url, stream=True)
+ total_size_in_bytes = int(response.headers.get('content-length', 0))
+ block_size = 1024 # 1 Kibibyte
+ progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
+ with open(outf, 'wb') as file:
+ for data in response.iter_content(block_size):
+ progress_bar.update(len(data))
+ file.write(data)
+ progress_bar.close()
+ if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
+ print("ERROR, something went wrong")
+ print(f"Downloaded successfully to {outf}")
+ p_ckpt = outf
+ sd = torch.load(p_ckpt, map_location="cpu")
+ unet_lora_config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["unet_lora_target_modules"])
+ vae_lora_config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian", target_modules=sd["vae_lora_target_modules"])
+ vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
+ _sd_vae = vae.state_dict()
+ for k in sd["state_dict_vae"]:
+ _sd_vae[k] = sd["state_dict_vae"][k]
+ vae.load_state_dict(_sd_vae)
+ unet.add_adapter(unet_lora_config)
+ _sd_unet = unet.state_dict()
+ for k in sd["state_dict_unet"]:
+ _sd_unet[k] = sd["state_dict_unet"][k]
+ unet.load_state_dict(_sd_unet)
+
+ elif pretrained_name == "sketch_to_image_stochastic":
+ # download from url
+ url = "https://www.cs.cmu.edu/~img2img-turbo/models/sketch_to_image_stochastic_lora.pkl"
+ os.makedirs(ckpt_folder, exist_ok=True)
+ outf = os.path.join(ckpt_folder, "sketch_to_image_stochastic_lora.pkl")
+ if not os.path.exists(outf):
+ print(f"Downloading checkpoint to {outf}")
+ response = requests.get(url, stream=True)
+ total_size_in_bytes = int(response.headers.get('content-length', 0))
+ block_size = 1024 # 1 Kibibyte
+ progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
+ with open(outf, 'wb') as file:
+ for data in response.iter_content(block_size):
+ progress_bar.update(len(data))
+ file.write(data)
+ progress_bar.close()
+ if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
+ print("ERROR, something went wrong")
+ print(f"Downloaded successfully to {outf}")
+ p_ckpt = outf
+ convin_pretrained = copy.deepcopy(unet.conv_in)
+ unet.conv_in = TwinConv(convin_pretrained, unet.conv_in)
+ sd = torch.load(p_ckpt, map_location="cpu")
+ unet_lora_config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["unet_lora_target_modules"])
+ vae_lora_config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian", target_modules=sd["vae_lora_target_modules"])
+ vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
+ _sd_vae = vae.state_dict()
+ for k in sd["state_dict_vae"]:
+ _sd_vae[k] = sd["state_dict_vae"][k]
+ vae.load_state_dict(_sd_vae)
+ unet.add_adapter(unet_lora_config)
+ _sd_unet = unet.state_dict()
+ for k in sd["state_dict_unet"]:
+ _sd_unet[k] = sd["state_dict_unet"][k]
+ unet.load_state_dict(_sd_unet)
+
+ elif pretrained_path is not None:
+ sd = torch.load(pretrained_path, map_location="cpu")
+ unet_lora_config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["unet_lora_target_modules"])
+ vae_lora_config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian", target_modules=sd["vae_lora_target_modules"])
+ vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
+ _sd_vae = vae.state_dict()
+ for k in sd["state_dict_vae"]:
+ _sd_vae[k] = sd["state_dict_vae"][k]
+ vae.load_state_dict(_sd_vae)
+ unet.add_adapter(unet_lora_config)
+ _sd_unet = unet.state_dict()
+ for k in sd["state_dict_unet"]:
+ _sd_unet[k] = sd["state_dict_unet"][k]
+ unet.load_state_dict(_sd_unet)
+
+ elif pretrained_name is None and pretrained_path is None:
+ print("Initializing model with random weights")
+ torch.nn.init.constant_(vae.decoder.skip_conv_1.weight, 1e-5)
+ torch.nn.init.constant_(vae.decoder.skip_conv_2.weight, 1e-5)
+ torch.nn.init.constant_(vae.decoder.skip_conv_3.weight, 1e-5)
+ torch.nn.init.constant_(vae.decoder.skip_conv_4.weight, 1e-5)
+ target_modules_vae = ["conv1", "conv2", "conv_in", "conv_shortcut", "conv", "conv_out",
+ "skip_conv_1", "skip_conv_2", "skip_conv_3", "skip_conv_4",
+ "to_k", "to_q", "to_v", "to_out.0",
+ ]
+ vae_lora_config = LoraConfig(r=lora_rank_vae, init_lora_weights="gaussian",
+ target_modules=target_modules_vae)
+ vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
+ target_modules_unet = [
+ "to_k", "to_q", "to_v", "to_out.0", "conv", "conv1", "conv2", "conv_shortcut", "conv_out",
+ "proj_in", "proj_out", "ff.net.2", "ff.net.0.proj"
+ ]
+ unet_lora_config = LoraConfig(r=lora_rank_unet, init_lora_weights="gaussian",
+ target_modules=target_modules_unet
+ )
+ unet.add_adapter(unet_lora_config)
+ self.lora_rank_unet = lora_rank_unet
+ self.lora_rank_vae = lora_rank_vae
+ self.target_modules_vae = target_modules_vae
+ self.target_modules_unet = target_modules_unet
+
+ # unet.enable_xformers_memory_efficient_attention()
+ unet.to("cuda")
+ vae.to("cuda")
+ self.unet, self.vae = unet, vae
+ self.vae.decoder.gamma = 1
+ self.timesteps = torch.tensor([999], device="cuda").long()
+ self.text_encoder.requires_grad_(False)
+
+ def set_eval(self):
+ self.unet.eval()
+ self.vae.eval()
+ self.unet.requires_grad_(False)
+ self.vae.requires_grad_(False)
+
+ def set_train(self):
+ self.unet.train()
+ self.vae.train()
+ for n, _p in self.unet.named_parameters():
+ if "lora" in n:
+ _p.requires_grad = True
+ self.unet.conv_in.requires_grad_(True)
+ for n, _p in self.vae.named_parameters():
+ if "lora" in n:
+ _p.requires_grad = True
+ self.vae.decoder.skip_conv_1.requires_grad_(True)
+ self.vae.decoder.skip_conv_2.requires_grad_(True)
+ self.vae.decoder.skip_conv_3.requires_grad_(True)
+ self.vae.decoder.skip_conv_4.requires_grad_(True)
+
+ def forward(self, c_t, prompt=None, prompt_tokens=None, deterministic=True, r=1.0, noise_map=None):
+ # either the prompt or the prompt_tokens should be provided
+ assert (prompt is None) != (prompt_tokens is None), "Either prompt or prompt_tokens should be provided"
+
+ if prompt is not None:
+ # encode the text prompt
+ caption_tokens = self.tokenizer(prompt, max_length=self.tokenizer.model_max_length,
+ padding="max_length", truncation=True, return_tensors="pt").input_ids.cuda()
+ caption_enc = self.text_encoder(caption_tokens)[0]
+ else:
+ caption_enc = self.text_encoder(prompt_tokens)[0]
+ if deterministic:
+ encoded_control = self.vae.encode(c_t).latent_dist.sample() * self.vae.config.scaling_factor
+ model_pred = self.unet(encoded_control, self.timesteps, encoder_hidden_states=caption_enc,).sample
+ x_denoised = self.sched.step(model_pred, self.timesteps, encoded_control, return_dict=True).prev_sample
+ self.vae.decoder.incoming_skip_acts = self.vae.encoder.current_down_blocks
+ output_image = (self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample).clamp(-1, 1)
+ else:
+ # scale the lora weights based on the r value
+ self.unet.set_adapters(["default"], weights=[r])
+ set_weights_and_activate_adapters(self.vae, ["vae_skip"], [r])
+ encoded_control = self.vae.encode(c_t).latent_dist.sample() * self.vae.config.scaling_factor
+ # combine the input and noise
+ unet_input = encoded_control * r + noise_map * (1 - r)
+ self.unet.conv_in.r = r
+ unet_output = self.unet(unet_input, self.timesteps, encoder_hidden_states=caption_enc,).sample
+ self.unet.conv_in.r = None
+ x_denoised = self.sched.step(unet_output, self.timesteps, unet_input, return_dict=True).prev_sample
+ self.vae.decoder.incoming_skip_acts = self.vae.encoder.current_down_blocks
+ self.vae.decoder.gamma = r
+ output_image = (self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample).clamp(-1, 1)
+ return output_image
+
+ def save_model(self, outf):
+ sd = {}
+ sd["unet_lora_target_modules"] = self.target_modules_unet
+ sd["vae_lora_target_modules"] = self.target_modules_vae
+ sd["rank_unet"] = self.lora_rank_unet
+ sd["rank_vae"] = self.lora_rank_vae
+ sd["state_dict_unet"] = {k: v for k, v in self.unet.state_dict().items() if "lora" in k or "conv_in" in k}
+ sd["state_dict_vae"] = {k: v for k, v in self.vae.state_dict().items() if "lora" in k or "skip" in k}
+ torch.save(sd, outf)
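
A minimal inference sketch for the class above (not part of the upload): it assumes the imports in this file and a CUDA device; the input file name, the prompt, and the [-1, 1] normalization of the conditioning image are assumptions, since the dataset preprocessing lives in my_utils/training_utils and is not shown here.

    import torch
    from PIL import Image
    from torchvision import transforms
    from pix2pix_turbo import Pix2Pix_Turbo

    model = Pix2Pix_Turbo(pretrained_name="edge_to_image")  # downloads the LoRA checkpoint on first use
    model.set_eval()
    # hypothetical conditioning image, mapped to [-1, 1] with a batch dimension added
    c_t = transforms.ToTensor()(Image.open("edge_map.png").convert("RGB"))
    c_t = (c_t * 2 - 1).unsqueeze(0).cuda()
    with torch.no_grad():
        out = model(c_t, prompt="a photo of a modern house", deterministic=True)
    transforms.ToPILImage()(out[0].cpu() * 0.5 + 0.5).save("output.png")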
src/train_cyclegan_turbo.py ADDED
@@ -0,0 +1,389 @@
+ import os
+ import gc
+ import copy
+ import lpips
+ import torch
+ import wandb
+ from glob import glob
+ import numpy as np
+ from accelerate import Accelerator
+ from accelerate.utils import set_seed
+ from PIL import Image
+ from torchvision import transforms
+ from tqdm.auto import tqdm
+ from transformers import AutoTokenizer, CLIPTextModel
+ from diffusers.optimization import get_scheduler
+ from peft.utils import get_peft_model_state_dict
+ from cleanfid.fid import get_folder_features, build_feature_extractor, frechet_distance
+ import vision_aided_loss
+ from model import make_1step_sched
+ from cyclegan_turbo import CycleGAN_Turbo, VAE_encode, VAE_decode, initialize_unet, initialize_vae
+ from my_utils.training_utils import UnpairedDataset, build_transform, parse_args_unpaired_training
+ from my_utils.dino_struct import DinoStructureLoss
+
+
+ def main(args):
+ accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, log_with=args.report_to)
+ set_seed(args.seed)
+
+ if accelerator.is_main_process:
+ os.makedirs(os.path.join(args.output_dir, "checkpoints"), exist_ok=True)
+
+ tokenizer = AutoTokenizer.from_pretrained("stabilityai/sd-turbo", subfolder="tokenizer", revision=args.revision, use_fast=False,)
+ noise_scheduler_1step = make_1step_sched()
+ text_encoder = CLIPTextModel.from_pretrained("stabilityai/sd-turbo", subfolder="text_encoder").cuda()
+
+ unet, l_modules_unet_encoder, l_modules_unet_decoder, l_modules_unet_others = initialize_unet(args.lora_rank_unet, return_lora_module_names=True)
+ vae_a2b, vae_lora_target_modules = initialize_vae(args.lora_rank_vae, return_lora_module_names=True)
+
+ weight_dtype = torch.float32
+ vae_a2b.to(accelerator.device, dtype=weight_dtype)
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+ unet.to(accelerator.device, dtype=weight_dtype)
+ text_encoder.requires_grad_(False)
+
+ if args.gan_disc_type == "vagan_clip":
+ net_disc_a = vision_aided_loss.Discriminator(cv_type='clip', loss_type=args.gan_loss_type, device="cuda")
+ net_disc_a.cv_ensemble.requires_grad_(False) # Freeze feature extractor
+ net_disc_b = vision_aided_loss.Discriminator(cv_type='clip', loss_type=args.gan_loss_type, device="cuda")
+ net_disc_b.cv_ensemble.requires_grad_(False) # Freeze feature extractor
+
+ crit_cycle, crit_idt = torch.nn.L1Loss(), torch.nn.L1Loss()
+
+ if args.enable_xformers_memory_efficient_attention:
+ unet.enable_xformers_memory_efficient_attention()
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ unet.conv_in.requires_grad_(True)
+ vae_b2a = copy.deepcopy(vae_a2b)
+ params_gen = CycleGAN_Turbo.get_traininable_params(unet, vae_a2b, vae_b2a)
+
+ vae_enc = VAE_encode(vae_a2b, vae_b2a=vae_b2a)
+ vae_dec = VAE_decode(vae_a2b, vae_b2a=vae_b2a)
+
+ optimizer_gen = torch.optim.AdamW(params_gen, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay, eps=args.adam_epsilon,)
+
+ params_disc = list(net_disc_a.parameters()) + list(net_disc_b.parameters())
+ optimizer_disc = torch.optim.AdamW(params_disc, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay, eps=args.adam_epsilon,)
+
+ dataset_train = UnpairedDataset(dataset_folder=args.dataset_folder, image_prep=args.train_img_prep, split="train", tokenizer=tokenizer)
+ train_dataloader = torch.utils.data.DataLoader(dataset_train, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers)
+ T_val = build_transform(args.val_img_prep)
+ fixed_caption_src = dataset_train.fixed_caption_src
+ fixed_caption_tgt = dataset_train.fixed_caption_tgt
+ l_images_src_test = []
+ for ext in ["*.jpg", "*.jpeg", "*.png", "*.bmp"]:
+ l_images_src_test.extend(glob(os.path.join(args.dataset_folder, "test_A", ext)))
+ l_images_tgt_test = []
+ for ext in ["*.jpg", "*.jpeg", "*.png", "*.bmp"]:
+ l_images_tgt_test.extend(glob(os.path.join(args.dataset_folder, "test_B", ext)))
+ l_images_src_test, l_images_tgt_test = sorted(l_images_src_test), sorted(l_images_tgt_test)
+
+ # make the reference FID statistics
+ if accelerator.is_main_process:
+ feat_model = build_feature_extractor("clean", "cuda", use_dataparallel=False)
+ """
+ FID reference statistics for A -> B translation
+ """
+ output_dir_ref = os.path.join(args.output_dir, "fid_reference_a2b")
+ os.makedirs(output_dir_ref, exist_ok=True)
+ # transform all images according to the validation transform and save them
+ for _path in tqdm(l_images_tgt_test):
+ _img = T_val(Image.open(_path).convert("RGB"))
+ outf = os.path.join(output_dir_ref, os.path.basename(_path)).replace(".jpg", ".png")
+ if not os.path.exists(outf):
+ _img.save(outf)
+ # compute the features for the reference images
+ ref_features = get_folder_features(output_dir_ref, model=feat_model, num_workers=0, num=None,
+ shuffle=False, seed=0, batch_size=8, device=torch.device("cuda"),
+ mode="clean", custom_fn_resize=None, description="", verbose=True,
+ custom_image_tranform=None)
+ a2b_ref_mu, a2b_ref_sigma = np.mean(ref_features, axis=0), np.cov(ref_features, rowvar=False)
+ """
+ FID reference statistics for B -> A translation
+ """
+ # transform all images according to the validation transform and save them
+ output_dir_ref = os.path.join(args.output_dir, "fid_reference_b2a")
+ os.makedirs(output_dir_ref, exist_ok=True)
+ for _path in tqdm(l_images_src_test):
+ _img = T_val(Image.open(_path).convert("RGB"))
+ outf = os.path.join(output_dir_ref, os.path.basename(_path)).replace(".jpg", ".png")
+ if not os.path.exists(outf):
+ _img.save(outf)
+ # compute the features for the reference images
+ ref_features = get_folder_features(output_dir_ref, model=feat_model, num_workers=0, num=None,
+ shuffle=False, seed=0, batch_size=8, device=torch.device("cuda"),
+ mode="clean", custom_fn_resize=None, description="", verbose=True,
+ custom_image_tranform=None)
+ b2a_ref_mu, b2a_ref_sigma = np.mean(ref_features, axis=0), np.cov(ref_features, rowvar=False)
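+ # note: at validation time frechet_distance compares these reference statistics against the
+ # (mu, sigma) of generated samples, i.e. FID = ||mu_r - mu_g||^2 + Tr(S_r + S_g - 2(S_r S_g)^(1/2))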
+
+ lr_scheduler_gen = get_scheduler(args.lr_scheduler, optimizer=optimizer_gen,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles, power=args.lr_power)
+ lr_scheduler_disc = get_scheduler(args.lr_scheduler, optimizer=optimizer_disc,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles, power=args.lr_power)
+
+ net_lpips = lpips.LPIPS(net='vgg')
+ net_lpips.cuda()
+ net_lpips.requires_grad_(False)
+
+ fixed_a2b_tokens = tokenizer(fixed_caption_tgt, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt").input_ids[0]
+ fixed_a2b_emb_base = text_encoder(fixed_a2b_tokens.cuda().unsqueeze(0))[0].detach()
+ fixed_b2a_tokens = tokenizer(fixed_caption_src, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt").input_ids[0]
+ fixed_b2a_emb_base = text_encoder(fixed_b2a_tokens.cuda().unsqueeze(0))[0].detach()
+ del text_encoder, tokenizer # free up some memory
+
+ unet, vae_enc, vae_dec, net_disc_a, net_disc_b = accelerator.prepare(unet, vae_enc, vae_dec, net_disc_a, net_disc_b)
+ net_lpips, optimizer_gen, optimizer_disc, train_dataloader, lr_scheduler_gen, lr_scheduler_disc = accelerator.prepare(
+ net_lpips, optimizer_gen, optimizer_disc, train_dataloader, lr_scheduler_gen, lr_scheduler_disc
+ )
+ if accelerator.is_main_process:
+ accelerator.init_trackers(args.tracker_project_name, config=dict(vars(args)))
+
+ first_epoch = 0
+ global_step = 0
+ progress_bar = tqdm(range(0, args.max_train_steps), initial=global_step, desc="Steps",
+ disable=not accelerator.is_local_main_process,)
+ # turn off eff. attn for the disc
+ for name, module in net_disc_a.named_modules():
+ if "attn" in name:
+ module.fused_attn = False
+ for name, module in net_disc_b.named_modules():
+ if "attn" in name:
+ module.fused_attn = False
+
+ for epoch in range(first_epoch, args.max_train_epochs):
+ for step, batch in enumerate(train_dataloader):
+ l_acc = [unet, net_disc_a, net_disc_b, vae_enc, vae_dec]
+ with accelerator.accumulate(*l_acc):
+ img_a = batch["pixel_values_src"].to(dtype=weight_dtype)
+ img_b = batch["pixel_values_tgt"].to(dtype=weight_dtype)
+
+ bsz = img_a.shape[0]
+ fixed_a2b_emb = fixed_a2b_emb_base.repeat(bsz, 1, 1).to(dtype=weight_dtype)
+ fixed_b2a_emb = fixed_b2a_emb_base.repeat(bsz, 1, 1).to(dtype=weight_dtype)
+ timesteps = torch.tensor([noise_scheduler_1step.config.num_train_timesteps - 1] * bsz, device=img_a.device).long()
+
+ """
+ Cycle Objective
+ """
+ # A -> fake B -> rec A
+ cyc_fake_b = CycleGAN_Turbo.forward_with_networks(img_a, "a2b", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_a2b_emb)
+ cyc_rec_a = CycleGAN_Turbo.forward_with_networks(cyc_fake_b, "b2a", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_b2a_emb)
+ loss_cycle_a = crit_cycle(cyc_rec_a, img_a) * args.lambda_cycle
+ loss_cycle_a += net_lpips(cyc_rec_a, img_a).mean() * args.lambda_cycle_lpips
+ # B -> fake A -> rec B
+ cyc_fake_a = CycleGAN_Turbo.forward_with_networks(img_b, "b2a", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_b2a_emb)
+ cyc_rec_b = CycleGAN_Turbo.forward_with_networks(cyc_fake_a, "a2b", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_a2b_emb)
+ loss_cycle_b = crit_cycle(cyc_rec_b, img_b) * args.lambda_cycle
+ loss_cycle_b += net_lpips(cyc_rec_b, img_b).mean() * args.lambda_cycle_lpips
+ accelerator.backward(loss_cycle_a + loss_cycle_b, retain_graph=False)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(params_gen, args.max_grad_norm)
+
+ optimizer_gen.step()
+ lr_scheduler_gen.step()
+ optimizer_gen.zero_grad()
+
+ """
+ Generator Objective (GAN) for task a->b and b->a (fake inputs)
+ """
+ fake_a = CycleGAN_Turbo.forward_with_networks(img_b, "b2a", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_b2a_emb)
+ fake_b = CycleGAN_Turbo.forward_with_networks(img_a, "a2b", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_a2b_emb)
+ loss_gan_a = net_disc_a(fake_b, for_G=True).mean() * args.lambda_gan
+ loss_gan_b = net_disc_b(fake_a, for_G=True).mean() * args.lambda_gan
+ accelerator.backward(loss_gan_a + loss_gan_b, retain_graph=False)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(params_gen, args.max_grad_norm)
+ optimizer_gen.step()
+ lr_scheduler_gen.step()
+ optimizer_gen.zero_grad()
+
+ """
+ Identity Objective
+ """
+ idt_a = CycleGAN_Turbo.forward_with_networks(img_b, "a2b", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_a2b_emb)
+ loss_idt_a = crit_idt(idt_a, img_b) * args.lambda_idt
+ loss_idt_a += net_lpips(idt_a, img_b).mean() * args.lambda_idt_lpips
+ idt_b = CycleGAN_Turbo.forward_with_networks(img_a, "b2a", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_b2a_emb)
+ loss_idt_b = crit_idt(idt_b, img_a) * args.lambda_idt
+ loss_idt_b += net_lpips(idt_b, img_a).mean() * args.lambda_idt_lpips
+ loss_g_idt = loss_idt_a + loss_idt_b
+ accelerator.backward(loss_g_idt, retain_graph=False)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(params_gen, args.max_grad_norm)
+ optimizer_gen.step()
+ lr_scheduler_gen.step()
+ optimizer_gen.zero_grad()
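+ # summary: the generator networks receive three separate updates per batch -- cycle consistency
+ # (L1 + LPIPS, weighted by lambda_cycle / lambda_cycle_lpips), adversarial loss from the CLIP-based
+ # discriminators (lambda_gan), and identity loss (L1 + LPIPS, lambda_idt / lambda_idt_lpips)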
+
+ """
+ Discriminator for task a->b and b->a (fake inputs)
+ """
+ loss_D_A_fake = net_disc_a(fake_b.detach(), for_real=False).mean() * args.lambda_gan
+ loss_D_B_fake = net_disc_b(fake_a.detach(), for_real=False).mean() * args.lambda_gan
+ loss_D_fake = (loss_D_A_fake + loss_D_B_fake) * 0.5
+ accelerator.backward(loss_D_fake, retain_graph=False)
+ if accelerator.sync_gradients:
+ params_to_clip = list(net_disc_a.parameters()) + list(net_disc_b.parameters())
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer_disc.step()
+ lr_scheduler_disc.step()
+ optimizer_disc.zero_grad()
+
+ """
+ Discriminator for task a->b and b->a (real inputs)
+ """
+ loss_D_A_real = net_disc_a(img_b, for_real=True).mean() * args.lambda_gan
+ loss_D_B_real = net_disc_b(img_a, for_real=True).mean() * args.lambda_gan
+ loss_D_real = (loss_D_A_real + loss_D_B_real) * 0.5
+ accelerator.backward(loss_D_real, retain_graph=False)
+ if accelerator.sync_gradients:
+ params_to_clip = list(net_disc_a.parameters()) + list(net_disc_b.parameters())
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer_disc.step()
+ lr_scheduler_disc.step()
+ optimizer_disc.zero_grad()
+
+ logs = {}
+ logs["cycle_a"] = loss_cycle_a.detach().item()
+ logs["cycle_b"] = loss_cycle_b.detach().item()
+ logs["gan_a"] = loss_gan_a.detach().item()
+ logs["gan_b"] = loss_gan_b.detach().item()
+ logs["disc_a"] = loss_D_A_fake.detach().item() + loss_D_A_real.detach().item()
+ logs["disc_b"] = loss_D_B_fake.detach().item() + loss_D_B_real.detach().item()
+ logs["idt_a"] = loss_idt_a.detach().item()
+ logs["idt_b"] = loss_idt_b.detach().item()
+
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ eval_unet = accelerator.unwrap_model(unet)
+ eval_vae_enc = accelerator.unwrap_model(vae_enc)
+ eval_vae_dec = accelerator.unwrap_model(vae_dec)
+ if global_step % args.viz_freq == 1:
+ for tracker in accelerator.trackers:
+ if tracker.name == "wandb":
+ viz_img_a = batch["pixel_values_src"].to(dtype=weight_dtype)
+ viz_img_b = batch["pixel_values_tgt"].to(dtype=weight_dtype)
+ log_dict = {
+ "train/real_a": [wandb.Image(viz_img_a[idx].float().detach().cpu(), caption=f"idx={idx}") for idx in range(bsz)],
+ "train/real_b": [wandb.Image(viz_img_b[idx].float().detach().cpu(), caption=f"idx={idx}") for idx in range(bsz)],
+ }
+ log_dict["train/rec_a"] = [wandb.Image(cyc_rec_a[idx].float().detach().cpu(), caption=f"idx={idx}") for idx in range(bsz)]
+ log_dict["train/rec_b"] = [wandb.Image(cyc_rec_b[idx].float().detach().cpu(), caption=f"idx={idx}") for idx in range(bsz)]
+ log_dict["train/fake_b"] = [wandb.Image(fake_b[idx].float().detach().cpu(), caption=f"idx={idx}") for idx in range(bsz)]
+ log_dict["train/fake_a"] = [wandb.Image(fake_a[idx].float().detach().cpu(), caption=f"idx={idx}") for idx in range(bsz)]
+ tracker.log(log_dict)
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ if global_step % args.checkpointing_steps == 1:
+ outf = os.path.join(args.output_dir, "checkpoints", f"model_{global_step}.pkl")
+ sd = {}
+ sd["l_target_modules_encoder"] = l_modules_unet_encoder
+ sd["l_target_modules_decoder"] = l_modules_unet_decoder
+ sd["l_modules_others"] = l_modules_unet_others
+ sd["rank_unet"] = args.lora_rank_unet
+ sd["sd_encoder"] = get_peft_model_state_dict(eval_unet, adapter_name="default_encoder")
+ sd["sd_decoder"] = get_peft_model_state_dict(eval_unet, adapter_name="default_decoder")
+ sd["sd_other"] = get_peft_model_state_dict(eval_unet, adapter_name="default_others")
+ sd["rank_vae"] = args.lora_rank_vae
+ sd["vae_lora_target_modules"] = vae_lora_target_modules
+ sd["sd_vae_enc"] = eval_vae_enc.state_dict()
+ sd["sd_vae_dec"] = eval_vae_dec.state_dict()
+ torch.save(sd, outf)
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ # compute val FID and DINO-Struct scores
+ if global_step % args.validation_steps == 1:
+ _timesteps = torch.tensor([noise_scheduler_1step.config.num_train_timesteps - 1] * 1, device="cuda").long()
+ net_dino = DinoStructureLoss()
+ """
+ Evaluate "A->B"
+ """
+ fid_output_dir = os.path.join(args.output_dir, f"fid-{global_step}/samples_a2b")
+ os.makedirs(fid_output_dir, exist_ok=True)
+ l_dino_scores_a2b = []
+ # get val input images from domain a
+ for idx, input_img_path in enumerate(tqdm(l_images_src_test)):
+ if idx > args.validation_num_images and args.validation_num_images > 0:
+ break
+ outf = os.path.join(fid_output_dir, f"{idx}.png")
+ with torch.no_grad():
+ input_img = T_val(Image.open(input_img_path).convert("RGB"))
+ img_a = transforms.ToTensor()(input_img)
+ img_a = transforms.Normalize([0.5], [0.5])(img_a).unsqueeze(0).cuda()
+ eval_fake_b = CycleGAN_Turbo.forward_with_networks(img_a, "a2b", eval_vae_enc, eval_unet,
+ eval_vae_dec, noise_scheduler_1step, _timesteps, fixed_a2b_emb[0:1])
+ eval_fake_b_pil = transforms.ToPILImage()(eval_fake_b[0] * 0.5 + 0.5)
+ eval_fake_b_pil.save(outf)
+ a = net_dino.preprocess(input_img).unsqueeze(0).cuda()
+ b = net_dino.preprocess(eval_fake_b_pil).unsqueeze(0).cuda()
+ dino_ssim = net_dino.calculate_global_ssim_loss(a, b).item()
+ l_dino_scores_a2b.append(dino_ssim)
+ dino_score_a2b = np.mean(l_dino_scores_a2b)
+ gen_features = get_folder_features(fid_output_dir, model=feat_model, num_workers=0, num=None,
+ shuffle=False, seed=0, batch_size=8, device=torch.device("cuda"),
+ mode="clean", custom_fn_resize=None, description="", verbose=True,
+ custom_image_tranform=None)
+ ed_mu, ed_sigma = np.mean(gen_features, axis=0), np.cov(gen_features, rowvar=False)
+ score_fid_a2b = frechet_distance(a2b_ref_mu, a2b_ref_sigma, ed_mu, ed_sigma)
+ print(f"step={global_step}, fid(a2b)={score_fid_a2b:.2f}, dino(a2b)={dino_score_a2b:.3f}")
+
+ """
+ compute FID for "B->A"
+ """
+ fid_output_dir = os.path.join(args.output_dir, f"fid-{global_step}/samples_b2a")
+ os.makedirs(fid_output_dir, exist_ok=True)
+ l_dino_scores_b2a = []
+ # get val input images from domain b
+ for idx, input_img_path in enumerate(tqdm(l_images_tgt_test)):
+ if idx > args.validation_num_images and args.validation_num_images > 0:
+ break
+ outf = os.path.join(fid_output_dir, f"{idx}.png")
+ with torch.no_grad():
+ input_img = T_val(Image.open(input_img_path).convert("RGB"))
+ img_b = transforms.ToTensor()(input_img)
+ img_b = transforms.Normalize([0.5], [0.5])(img_b).unsqueeze(0).cuda()
+ eval_fake_a = CycleGAN_Turbo.forward_with_networks(img_b, "b2a", eval_vae_enc, eval_unet,
+ eval_vae_dec, noise_scheduler_1step, _timesteps, fixed_b2a_emb[0:1])
+ eval_fake_a_pil = transforms.ToPILImage()(eval_fake_a[0] * 0.5 + 0.5)
+ eval_fake_a_pil.save(outf)
+ a = net_dino.preprocess(input_img).unsqueeze(0).cuda()
+ b = net_dino.preprocess(eval_fake_a_pil).unsqueeze(0).cuda()
+ dino_ssim = net_dino.calculate_global_ssim_loss(a, b).item()
+ l_dino_scores_b2a.append(dino_ssim)
+ dino_score_b2a = np.mean(l_dino_scores_b2a)
+ gen_features = get_folder_features(fid_output_dir, model=feat_model, num_workers=0, num=None,
+ shuffle=False, seed=0, batch_size=8, device=torch.device("cuda"),
+ mode="clean", custom_fn_resize=None, description="", verbose=True,
+ custom_image_tranform=None)
+ ed_mu, ed_sigma = np.mean(gen_features, axis=0), np.cov(gen_features, rowvar=False)
+ score_fid_b2a = frechet_distance(b2a_ref_mu, b2a_ref_sigma, ed_mu, ed_sigma)
+ print(f"step={global_step}, fid(b2a)={score_fid_b2a}, dino(b2a)={dino_score_b2a:.3f}")
+ logs["val/fid_a2b"], logs["val/fid_b2a"] = score_fid_a2b, score_fid_b2a
+ logs["val/dino_struct_a2b"], logs["val/dino_struct_b2a"] = dino_score_a2b, dino_score_b2a
+ del net_dino # free up memory
+
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+ if global_step >= args.max_train_steps:
+ break
+
+
+ if __name__ == "__main__":
+ args = parse_args_unpaired_training()
+ main(args)
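
A possible launch command for this script (a sketch, not part of the upload): the flag names below mirror the args.* attributes used above, but their exact spelling and defaults are defined in parse_args_unpaired_training inside my_utils/training_utils, which is not shown here; the dataset path and image-prep names are placeholders.

    accelerate launch src/train_cyclegan_turbo.py \
        --dataset_folder data/my_unpaired_dataset \
        --output_dir output/cyclegan_turbo_run \
        --train_img_prep resize_crop_256 --val_img_prep no_resize \
        --train_batch_size 1 --max_train_steps 25000 \
        --lora_rank_unet 128 --lora_rank_vae 4 \
        --learning_rate 1e-5 --report_to wandb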
src/train_pix2pix_turbo.py ADDED
@@ -0,0 +1,307 @@
+ import os
+ import gc
+ import lpips
+ import clip
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import torch.utils.checkpoint
+ import transformers
+ from accelerate import Accelerator
+ from accelerate.utils import set_seed
+ from PIL import Image
+ from torchvision import transforms
+ from tqdm.auto import tqdm
+
+ import diffusers
+ from diffusers.utils.import_utils import is_xformers_available
+ from diffusers.optimization import get_scheduler
+
+ import wandb
+ from cleanfid.fid import get_folder_features, build_feature_extractor, fid_from_feats
+
+ from pix2pix_turbo import Pix2Pix_Turbo
+ from my_utils.training_utils import parse_args_paired_training, PairedDataset
+
+
+ def main(args):
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ )
+
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ if accelerator.is_main_process:
+ os.makedirs(os.path.join(args.output_dir, "checkpoints"), exist_ok=True)
+ os.makedirs(os.path.join(args.output_dir, "eval"), exist_ok=True)
+
+ if args.pretrained_model_name_or_path == "stabilityai/sd-turbo":
+ net_pix2pix = Pix2Pix_Turbo(lora_rank_unet=args.lora_rank_unet, lora_rank_vae=args.lora_rank_vae)
+ net_pix2pix.set_train()
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ net_pix2pix.unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available, please install it by running `pip install xformers`")
+
+ if args.gradient_checkpointing:
+ net_pix2pix.unet.enable_gradient_checkpointing()
+
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.gan_disc_type == "vagan_clip":
+ import vision_aided_loss
+ net_disc = vision_aided_loss.Discriminator(cv_type='clip', loss_type=args.gan_loss_type, device="cuda")
+ else:
+ raise NotImplementedError(f"Discriminator type {args.gan_disc_type} not implemented")
+
+ net_disc = net_disc.cuda()
+ net_disc.requires_grad_(True)
+ net_disc.cv_ensemble.requires_grad_(False)
+ net_disc.train()
+
+ net_lpips = lpips.LPIPS(net='vgg').cuda()
+ net_clip, _ = clip.load("ViT-B/32", device="cuda")
+ net_clip.requires_grad_(False)
+ net_clip.eval()
+
+ net_lpips.requires_grad_(False)
+
+ # make the optimizer
+ layers_to_opt = []
+ for n, _p in net_pix2pix.unet.named_parameters():
+ if "lora" in n:
+ assert _p.requires_grad
+ layers_to_opt.append(_p)
+ layers_to_opt += list(net_pix2pix.unet.conv_in.parameters())
+ for n, _p in net_pix2pix.vae.named_parameters():
+ if "lora" in n and "vae_skip" in n:
+ assert _p.requires_grad
+ layers_to_opt.append(_p)
+ layers_to_opt = layers_to_opt + list(net_pix2pix.vae.decoder.skip_conv_1.parameters()) + \
+ list(net_pix2pix.vae.decoder.skip_conv_2.parameters()) + \
+ list(net_pix2pix.vae.decoder.skip_conv_3.parameters()) + \
+ list(net_pix2pix.vae.decoder.skip_conv_4.parameters())
+
+ optimizer = torch.optim.AdamW(layers_to_opt, lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,)
+ lr_scheduler = get_scheduler(args.lr_scheduler, optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles, power=args.lr_power,)
+
+ optimizer_disc = torch.optim.AdamW(net_disc.parameters(), lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,)
+ lr_scheduler_disc = get_scheduler(args.lr_scheduler, optimizer=optimizer_disc,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles, power=args.lr_power)
+
+ dataset_train = PairedDataset(dataset_folder=args.dataset_folder, image_prep=args.train_image_prep, split="train", tokenizer=net_pix2pix.tokenizer)
+ dl_train = torch.utils.data.DataLoader(dataset_train, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers)
+ dataset_val = PairedDataset(dataset_folder=args.dataset_folder, image_prep=args.test_image_prep, split="test", tokenizer=net_pix2pix.tokenizer)
+ dl_val = torch.utils.data.DataLoader(dataset_val, batch_size=1, shuffle=False, num_workers=0)
+
+ # Prepare everything with our `accelerator`.
+ net_pix2pix, net_disc, optimizer, optimizer_disc, dl_train, lr_scheduler, lr_scheduler_disc = accelerator.prepare(
+ net_pix2pix, net_disc, optimizer, optimizer_disc, dl_train, lr_scheduler, lr_scheduler_disc
+ )
+ net_clip, net_lpips = accelerator.prepare(net_clip, net_lpips)
+ # renormalize with CLIP's image normalization statistics
+ t_clip_renorm = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move all networks to device and cast to weight_dtype
+ net_pix2pix.to(accelerator.device, dtype=weight_dtype)
+ net_disc.to(accelerator.device, dtype=weight_dtype)
+ net_lpips.to(accelerator.device, dtype=weight_dtype)
+ net_clip.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
+ progress_bar = tqdm(range(0, args.max_train_steps), initial=0, desc="Steps",
+ disable=not accelerator.is_local_main_process,)
+
+ # turn off eff. attn for the discriminator
+ for name, module in net_disc.named_modules():
+ if "attn" in name:
+ module.fused_attn = False
+
+ # compute the reference stats for FID tracking
+ if accelerator.is_main_process and args.track_val_fid:
+ feat_model = build_feature_extractor("clean", "cuda", use_dataparallel=False)
+
+ def fn_transform(x):
+ x_pil = Image.fromarray(x)
+ out_pil = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.LANCZOS)(x_pil)
+ return np.array(out_pil)
+
+ ref_stats = get_folder_features(os.path.join(args.dataset_folder, "test_B"), model=feat_model, num_workers=0, num=None,
+ shuffle=False, seed=0, batch_size=8, device=torch.device("cuda"),
+ mode="clean", custom_image_tranform=fn_transform, description="", verbose=True)
+
+ # start the training loop
+ global_step = 0
+ for epoch in range(0, args.num_training_epochs):
+ for step, batch in enumerate(dl_train):
+ l_acc = [net_pix2pix, net_disc]
+ with accelerator.accumulate(*l_acc):
+ x_src = batch["conditioning_pixel_values"]
+ x_tgt = batch["output_pixel_values"]
+ B, C, H, W = x_src.shape
+ # forward pass
+ x_tgt_pred = net_pix2pix(x_src, prompt_tokens=batch["input_ids"], deterministic=True)
+ # Reconstruction loss
+ loss_l2 = F.mse_loss(x_tgt_pred.float(), x_tgt.float(), reduction="mean") * args.lambda_l2
+ loss_lpips = net_lpips(x_tgt_pred.float(), x_tgt.float()).mean() * args.lambda_lpips
+ loss = loss_l2 + loss_lpips
+ # CLIP similarity loss
+ if args.lambda_clipsim > 0:
+ x_tgt_pred_renorm = t_clip_renorm(x_tgt_pred * 0.5 + 0.5)
+ x_tgt_pred_renorm = F.interpolate(x_tgt_pred_renorm, (224, 224), mode="bilinear", align_corners=False)
+ caption_tokens = clip.tokenize(batch["caption"], truncate=True).to(x_tgt_pred.device)
+ clipsim, _ = net_clip(x_tgt_pred_renorm, caption_tokens)
+ loss_clipsim = (1 - clipsim.mean() / 100)
+ loss += loss_clipsim * args.lambda_clipsim
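+ # note: CLIP returns image-text logits scaled by its learned logit scale (roughly 100), so
+ # clipsim / 100 approximates the cosine similarity; this loss pushes that similarity toward 1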
+ accelerator.backward(loss, retain_graph=False)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(layers_to_opt, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
+
+ """
+ Generator loss: fool the discriminator
+ """
+ x_tgt_pred = net_pix2pix(x_src, prompt_tokens=batch["input_ids"], deterministic=True)
+ lossG = net_disc(x_tgt_pred, for_G=True).mean() * args.lambda_gan
+ accelerator.backward(lossG)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(layers_to_opt, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
+
+ """
+ Discriminator loss: fake image vs real image
+ """
+ # real image
+ lossD_real = net_disc(x_tgt.detach(), for_real=True).mean() * args.lambda_gan
+ accelerator.backward(lossD_real.mean())
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(net_disc.parameters(), args.max_grad_norm)
+ optimizer_disc.step()
+ lr_scheduler_disc.step()
+ optimizer_disc.zero_grad(set_to_none=args.set_grads_to_none)
+ # fake image
+ lossD_fake = net_disc(x_tgt_pred.detach(), for_real=False).mean() * args.lambda_gan
+ accelerator.backward(lossD_fake.mean())
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(net_disc.parameters(), args.max_grad_norm)
+ optimizer_disc.step()
+ optimizer_disc.zero_grad(set_to_none=args.set_grads_to_none)
+ lossD = lossD_real + lossD_fake
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ logs = {}
+ # log all the losses
+ logs["lossG"] = lossG.detach().item()
+ logs["lossD"] = lossD.detach().item()
+ logs["loss_l2"] = loss_l2.detach().item()
+ logs["loss_lpips"] = loss_lpips.detach().item()
+ if args.lambda_clipsim > 0:
+ logs["loss_clipsim"] = loss_clipsim.detach().item()
+ progress_bar.set_postfix(**logs)
+
+ # viz some images
+ if global_step % args.viz_freq == 1:
+ log_dict = {
+ "train/source": [wandb.Image(x_src[idx].float().detach().cpu(), caption=f"idx={idx}") for idx in range(B)],
+ "train/target": [wandb.Image(x_tgt[idx].float().detach().cpu(), caption=f"idx={idx}") for idx in range(B)],
+ "train/model_output": [wandb.Image(x_tgt_pred[idx].float().detach().cpu(), caption=f"idx={idx}") for idx in range(B)],
+ }
+ for k in log_dict:
+ logs[k] = log_dict[k]
+
+ # checkpoint the model
+ if global_step % args.checkpointing_steps == 1:
+ outf = os.path.join(args.output_dir, "checkpoints", f"model_{global_step}.pkl")
+ accelerator.unwrap_model(net_pix2pix).save_model(outf)
+
+ # compute validation set FID, L2, LPIPS, CLIP-SIM
+ if global_step % args.eval_freq == 1:
+ l_l2, l_lpips, l_clipsim = [], [], []
+ if args.track_val_fid:
+ os.makedirs(os.path.join(args.output_dir, "eval", f"fid_{global_step}"), exist_ok=True)
+ for step, batch_val in enumerate(dl_val):
+ if step >= args.num_samples_eval:
+ break
+ x_src = batch_val["conditioning_pixel_values"].cuda()
+ x_tgt = batch_val["output_pixel_values"].cuda()
+ B, C, H, W = x_src.shape
+ assert B == 1, "Use batch size 1 for eval."
+ with torch.no_grad():
+ # forward pass
+ x_tgt_pred = accelerator.unwrap_model(net_pix2pix)(x_src, prompt_tokens=batch_val["input_ids"].cuda(), deterministic=True)
+ # compute the reconstruction losses
+ loss_l2 = F.mse_loss(x_tgt_pred.float(), x_tgt.float(), reduction="mean")
+ loss_lpips = net_lpips(x_tgt_pred.float(), x_tgt.float()).mean()
+ # compute clip similarity loss
+ x_tgt_pred_renorm = t_clip_renorm(x_tgt_pred * 0.5 + 0.5)
+ x_tgt_pred_renorm = F.interpolate(x_tgt_pred_renorm, (224, 224), mode="bilinear", align_corners=False)
+ caption_tokens = clip.tokenize(batch_val["caption"], truncate=True).to(x_tgt_pred.device)
+ clipsim, _ = net_clip(x_tgt_pred_renorm, caption_tokens)
+ clipsim = clipsim.mean()
+
+ l_l2.append(loss_l2.item())
+ l_lpips.append(loss_lpips.item())
+ l_clipsim.append(clipsim.item())
+ # save output images to file for FID evaluation
+ if args.track_val_fid:
+ output_pil = transforms.ToPILImage()(x_tgt_pred[0].cpu() * 0.5 + 0.5)
+ outf = os.path.join(args.output_dir, "eval", f"fid_{global_step}", f"val_{step}.png")
+ output_pil.save(outf)
+ if args.track_val_fid:
+ curr_stats = get_folder_features(os.path.join(args.output_dir, "eval", f"fid_{global_step}"), model=feat_model, num_workers=0, num=None,
+ shuffle=False, seed=0, batch_size=8, device=torch.device("cuda"),
+ mode="clean", custom_image_tranform=fn_transform, description="", verbose=True)
+ fid_score = fid_from_feats(ref_stats, curr_stats)
+ logs["val/clean_fid"] = fid_score
+ logs["val/l2"] = np.mean(l_l2)
+ logs["val/lpips"] = np.mean(l_lpips)
+ logs["val/clipsim"] = np.mean(l_clipsim)
+ gc.collect()
+ torch.cuda.empty_cache()
+ accelerator.log(logs, step=global_step)
+
+
+ if __name__ == "__main__":
+ args = parse_args_paired_training()
+ main(args)
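
A possible launch command for the paired training script (again a sketch, not part of the upload): the flags mirror the args.* attributes referenced above, but their exact names and defaults come from parse_args_paired_training in my_utils/training_utils, which is not shown here; the dataset path and image-prep names are placeholders.

    accelerate launch src/train_pix2pix_turbo.py \
        --pretrained_model_name_or_path stabilityai/sd-turbo \
        --dataset_folder data/my_paired_dataset \
        --output_dir output/pix2pix_turbo_run \
        --train_image_prep resize_256 --test_image_prep resize_256 \
        --resolution 256 --train_batch_size 2 --max_train_steps 10000 \
        --lora_rank_unet 8 --lora_rank_vae 4 \
        --track_val_fid --report_to wandb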