MultiMatrix committed
Commit a4ea705 · verified · 1 Parent(s): 4bef05b

Upload train_stage2.py

Files changed (1):
  1. train_stage2.py +207 -0
train_stage2.py ADDED
@@ -0,0 +1,207 @@
import os
from argparse import ArgumentParser

from omegaconf import OmegaConf
import torch
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from accelerate import Accelerator
from accelerate.utils import set_seed
from einops import rearrange
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from PIL import Image, ImageDraw, ImageFont
import numpy as np

from model import ControlLDM, SwinIR, Diffusion
from utils.common import instantiate_from_config
from utils.sampler import SpacedSampler


def log_txt_as_img(wh, xc):
    # wh: a tuple of (width, height)
    # xc: a list of captions to plot
    b = len(xc)
    txts = list()
    for bi in range(b):
        txt = Image.new("RGB", wh, color="white")
        draw = ImageDraw.Draw(txt)
        # font = ImageFont.truetype('font/DejaVuSans.ttf', size=size)
        font = ImageFont.load_default()
        nc = int(40 * (wh[0] / 256))
        lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))

        try:
            draw.text((0, 0), lines, fill="black", font=font)
        except UnicodeEncodeError:
            print("Can't encode string for logging. Skipping.")

        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
        txts.append(txt)
    txts = np.stack(txts)
    txts = torch.tensor(txts)
    return txts
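
# Usage sketch (illustrative only; in this script the helper is called solely when
# logging prompts to TensorBoard below): each caption is wrapped at
# int(40 * width / 256) characters per line and rendered into a (B, 3, H, W)
# tensor with values in [-1, 1], e.g.
#   grid = log_txt_as_img((512, 512), ["a photo of a cat", "low quality, blurry"])
#   assert tuple(grid.shape) == (2, 3, 512, 512)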


def main(args) -> None:
    # Setup accelerator:
    accelerator = Accelerator(split_batches=True)
    set_seed(231)
    device = accelerator.device
    cfg = OmegaConf.load(args.config)

    # Setup an experiment folder:
    if accelerator.is_local_main_process:
        exp_dir = cfg.train.exp_dir
        os.makedirs(exp_dir, exist_ok=True)
        ckpt_dir = os.path.join(exp_dir, "checkpoints")
        os.makedirs(ckpt_dir, exist_ok=True)
        print(f"Experiment directory created at {exp_dir}")

    # Create model:
    cldm: ControlLDM = instantiate_from_config(cfg.model.cldm)
    sd = torch.load(cfg.train.sd_path, map_location="cpu")["state_dict"]
    unused = cldm.load_pretrained_sd(sd)
    if accelerator.is_local_main_process:
        print(f"strictly load pretrained SD weight from {cfg.train.sd_path}\n"
              f"unused weights: {unused}")

    if cfg.train.resume:
        cldm.load_controlnet_from_ckpt(torch.load(cfg.train.resume, map_location="cpu"))
        if accelerator.is_local_main_process:
            print(f"strictly load controlnet weight from checkpoint: {cfg.train.resume}")
    else:
        init_with_new_zero, init_with_scratch = cldm.load_controlnet_from_unet()
        if accelerator.is_local_main_process:
            print(f"strictly load controlnet weight from pretrained SD\n"
                  f"weights initialized with newly added zeros: {init_with_new_zero}\n"
                  f"weights initialized from scratch: {init_with_scratch}")

    swinir: SwinIR = instantiate_from_config(cfg.model.swinir)
    sd = {
        (k[len("module."):] if k.startswith("module.") else k): v
        for k, v in torch.load(cfg.train.swinir_path, map_location="cpu").items()
    }
    swinir.load_state_dict(sd, strict=True)
    for p in swinir.parameters():
        p.requires_grad = False
    if accelerator.is_local_main_process:
        print(f"load SwinIR from {cfg.train.swinir_path}")

    diffusion: Diffusion = instantiate_from_config(cfg.model.diffusion)

    # Setup optimizer:
    opt = torch.optim.AdamW(cldm.controlnet.parameters(), lr=cfg.train.learning_rate)

    # Setup data:
    dataset = instantiate_from_config(cfg.dataset.train)
    loader = DataLoader(
        dataset=dataset, batch_size=cfg.train.batch_size,
        num_workers=cfg.train.num_workers,
        shuffle=True, drop_last=True
    )
    if accelerator.is_local_main_process:
        print(f"Dataset contains {len(dataset):,} images from {dataset.file_list}")

    # Prepare models for training:
    cldm.train().to(device)
    swinir.eval().to(device)
    diffusion.to(device)
    cldm, opt, loader = accelerator.prepare(cldm, opt, loader)
    pure_cldm: ControlLDM = accelerator.unwrap_model(cldm)

    # Variables for monitoring/logging purposes:
    global_step = 0
    max_steps = cfg.train.train_steps
    step_loss = []
    epoch = 0
    epoch_loss = []
    sampler = SpacedSampler(diffusion.betas)
    if accelerator.is_local_main_process:
        writer = SummaryWriter(exp_dir)
        print(f"Training for {max_steps} steps...")

    while global_step < max_steps:
        pbar = tqdm(iterable=None, disable=not accelerator.is_local_main_process, unit="batch", total=len(loader))
        for gt, lq, prompt in loader:
            gt = rearrange(gt, "b h w c -> b c h w").contiguous().float().to(device)
            lq = rearrange(lq, "b h w c -> b c h w").contiguous().float().to(device)
            with torch.no_grad():
                z_0 = pure_cldm.vae_encode(gt)
                clean = swinir(lq)
                cond = pure_cldm.prepare_condition(clean, prompt)
            t = torch.randint(0, diffusion.num_timesteps, (z_0.shape[0],), device=device)

            loss = diffusion.p_losses(cldm, z_0, t, cond)
            opt.zero_grad()
            accelerator.backward(loss)
            opt.step()

            accelerator.wait_for_everyone()

            global_step += 1
            step_loss.append(loss.item())
            epoch_loss.append(loss.item())
            pbar.update(1)
            pbar.set_description(f"Epoch: {epoch:04d}, Global Step: {global_step:07d}, Loss: {loss.item():.6f}")

            # Log loss values:
            if global_step % cfg.train.log_every == 0 and global_step > 0:
                # Gather values from all processes
                avg_loss = accelerator.gather(torch.tensor(step_loss, device=device).unsqueeze(0)).mean().item()
                step_loss.clear()
                if accelerator.is_local_main_process:
                    writer.add_scalar("loss/loss_simple_step", avg_loss, global_step)

            # Save checkpoint:
            if global_step % cfg.train.ckpt_every == 0 and global_step > 0:
                if accelerator.is_local_main_process:
                    checkpoint = pure_cldm.controlnet.state_dict()
                    ckpt_path = f"{ckpt_dir}/{global_step:07d}.pt"
                    torch.save(checkpoint, ckpt_path)

            if global_step % cfg.train.image_every == 0 or global_step == 1:
                N = 12
                log_clean = clean[:N]
                log_cond = {k: v[:N] for k, v in cond.items()}
                log_gt, log_lq = gt[:N], lq[:N]
                log_prompt = prompt[:N]
                cldm.eval()
                with torch.no_grad():
                    z = sampler.sample(
                        model=cldm, device=device, steps=50, batch_size=len(log_gt), x_size=z_0.shape[1:],
                        cond=log_cond, uncond=None, cfg_scale=1.0, x_T=None,
                        progress=accelerator.is_local_main_process, progress_leave=False
                    )
                    if accelerator.is_local_main_process:
                        for tag, image in [
                            ("image/samples", (pure_cldm.vae_decode(z) + 1) / 2),
                            ("image/gt", (log_gt + 1) / 2),
                            ("image/lq", log_lq),
                            ("image/condition", log_clean),
                            ("image/condition_decoded", (pure_cldm.vae_decode(log_cond["c_img"]) + 1) / 2),
                            ("image/prompt", (log_txt_as_img((512, 512), log_prompt) + 1) / 2)
                        ]:
                            writer.add_image(tag, make_grid(image, nrow=4), global_step)
                cldm.train()
            accelerator.wait_for_everyone()
            if global_step == max_steps:
                break

        pbar.close()
        epoch += 1
        avg_epoch_loss = accelerator.gather(torch.tensor(epoch_loss, device=device).unsqueeze(0)).mean().item()
        epoch_loss.clear()
        if accelerator.is_local_main_process:
            writer.add_scalar("loss/loss_simple_epoch", avg_epoch_loss, global_step)

    if accelerator.is_local_main_process:
        print("done!")
        writer.close()


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    args = parser.parse_args()
    main(args)
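
Note on the config: the script is driven entirely by the YAML file passed via --config (and, since it builds an Accelerator, would typically be launched with something like accelerate launch train_stage2.py --config <your config>). Below is a minimal sketch of the config layout implied by the keys read above. The targets and numeric values are placeholders and assumptions, not part of this commit, and instantiate_from_config is assumed to follow the usual target/params convention.

    # Sketch of the expected config structure (placeholder values, not from this commit).
    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        "model": {
            "cldm": {"target": "...", "params": {}},       # instantiated as ControlLDM
            "swinir": {"target": "...", "params": {}},     # instantiated as SwinIR (frozen)
            "diffusion": {"target": "...", "params": {}},  # instantiated as Diffusion
        },
        "dataset": {
            "train": {"target": "...", "params": {}},      # dataset yielding (gt, lq, prompt); must expose .file_list
        },
        "train": {
            "exp_dir": "./exp",        # experiment dir: checkpoints/ and TensorBoard logs go here
            "sd_path": "...",          # pretrained Stable Diffusion checkpoint (loaded via ["state_dict"])
            "swinir_path": "...",      # stage-1 SwinIR checkpoint
            "resume": None,            # optional ControlNet checkpoint to resume from
            "learning_rate": 1e-4,     # placeholder
            "batch_size": 16,          # placeholder; global batch, split across processes (split_batches=True)
            "num_workers": 4,          # placeholder
            "train_steps": 30000,      # placeholder
            "log_every": 100,          # placeholder
            "ckpt_every": 5000,        # placeholder
            "image_every": 1000,       # placeholder
        },
    })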