Upload train_stage2.py

train_stage2.py  ADDED  (+207 -0)
@@ -0,0 +1,207 @@
import os
from argparse import ArgumentParser

from omegaconf import OmegaConf
import torch
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from accelerate import Accelerator
from accelerate.utils import set_seed
from einops import rearrange
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from PIL import Image, ImageDraw, ImageFont
import numpy as np

from model import ControlLDM, SwinIR, Diffusion
from utils.common import instantiate_from_config
from utils.sampler import SpacedSampler
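
# NOTE: `model` (ControlLDM, SwinIR, Diffusion) and `utils` are repo-local modules;
# the rest are third-party dependencies (torch, accelerate, einops, tensorboard, PIL).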


def log_txt_as_img(wh, xc):
    # wh: a tuple of (width, height)
    # xc: a list of captions to plot
    b = len(xc)
    txts = list()
    for bi in range(b):
        txt = Image.new("RGB", wh, color="white")
        draw = ImageDraw.Draw(txt)
        # font = ImageFont.truetype('font/DejaVuSans.ttf', size=size)
        font = ImageFont.load_default()
        # Wrap the caption at roughly 40 characters per 256 pixels of width.
        nc = int(40 * (wh[0] / 256))
        lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))

        try:
            draw.text((0, 0), lines, fill="black", font=font)
        except UnicodeEncodeError:
            print("Can't encode string for logging. Skipping.")

        # HWC uint8 -> CHW float in [-1, 1], matching the range of model outputs.
        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
        txts.append(txt)
    txts = np.stack(txts)
    txts = torch.tensor(txts)
    return txts
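
# Usage sketch for the helper above (hypothetical captions): it returns CHW tensors
# in [-1, 1], so shift to [0, 1] before logging, exactly as the training loop does:
#
#   grid = make_grid((log_txt_as_img((512, 512), ["a cat", "a dog"]) + 1) / 2, nrow=2)
#   writer.add_image("image/prompt", grid, global_step)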


def main(args) -> None:
    # Setup accelerator:
    accelerator = Accelerator(split_batches=True)
    set_seed(231)
    device = accelerator.device
    cfg = OmegaConf.load(args.config)

    # Setup an experiment folder:
    if accelerator.is_local_main_process:
        exp_dir = cfg.train.exp_dir
        os.makedirs(exp_dir, exist_ok=True)
        ckpt_dir = os.path.join(exp_dir, "checkpoints")
        os.makedirs(ckpt_dir, exist_ok=True)
        print(f"Experiment directory created at {exp_dir}")
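
    # NOTE: exp_dir and ckpt_dir exist only on the local main process; all later
    # uses (SummaryWriter, checkpoint saving) sit behind the same guard.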

    # Create model:
    cldm: ControlLDM = instantiate_from_config(cfg.model.cldm)
    sd = torch.load(cfg.train.sd_path, map_location="cpu")["state_dict"]
    unused = cldm.load_pretrained_sd(sd)
    if accelerator.is_local_main_process:
        print(f"strictly load pretrained SD weight from {cfg.train.sd_path}\n"
              f"unused weights: {unused}")

    if cfg.train.resume:
        cldm.load_controlnet_from_ckpt(torch.load(cfg.train.resume, map_location="cpu"))
        if accelerator.is_local_main_process:
            print(f"strictly load controlnet weight from checkpoint: {cfg.train.resume}")
    else:
        init_with_new_zero, init_with_scratch = cldm.load_controlnet_from_unet()
        if accelerator.is_local_main_process:
            print(f"strictly load controlnet weight from pretrained SD\n"
                  f"weights initialized with newly added zeros: {init_with_new_zero}\n"
                  f"weights initialized from scratch: {init_with_scratch}")

    swinir: SwinIR = instantiate_from_config(cfg.model.swinir)
    # Strip a possible DataParallel/DDP "module." prefix from the checkpoint keys.
    sd = {
        (k[len("module."):] if k.startswith("module.") else k): v
        for k, v in torch.load(cfg.train.swinir_path, map_location="cpu").items()
    }
    swinir.load_state_dict(sd, strict=True)
    for p in swinir.parameters():
        p.requires_grad = False
    if accelerator.is_local_main_process:
        print(f"load SwinIR from {cfg.train.swinir_path}")

    diffusion: Diffusion = instantiate_from_config(cfg.model.diffusion)
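
    # How the pieces fit (per the wiring above): the frozen SwinIR produces a
    # "clean" estimate of the low-quality input, which prepare_condition turns
    # into the ControlNet condition; diffusion supplies the noise schedule and
    # training loss, and only cldm.controlnet is optimized (see below).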

    # Setup optimizer:
    opt = torch.optim.AdamW(cldm.controlnet.parameters(), lr=cfg.train.learning_rate)

    # Setup data:
    dataset = instantiate_from_config(cfg.dataset.train)
    loader = DataLoader(
        dataset=dataset, batch_size=cfg.train.batch_size,
        num_workers=cfg.train.num_workers,
        shuffle=True, drop_last=True
    )
    if accelerator.is_local_main_process:
        print(f"Dataset contains {len(dataset):,} images from {dataset.file_list}")
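
    # With Accelerator(split_batches=True), cfg.train.batch_size is the *global*
    # batch size: accelerator.prepare() splits each batch across processes rather
    # than replicating it per process.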

    # Prepare models for training:
    cldm.train().to(device)
    swinir.eval().to(device)
    diffusion.to(device)
    cldm, opt, loader = accelerator.prepare(cldm, opt, loader)
    pure_cldm: ControlLDM = accelerator.unwrap_model(cldm)

    # Variables for monitoring/logging purposes:
    global_step = 0
    max_steps = cfg.train.train_steps
    step_loss = []
    epoch = 0
    epoch_loss = []
    sampler = SpacedSampler(diffusion.betas)
    if accelerator.is_local_main_process:
        writer = SummaryWriter(exp_dir)
        print(f"Training for {max_steps} steps...")
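
    # The SpacedSampler built from diffusion.betas is only used for the periodic
    # image logging inside the loop below; the training loss itself comes from
    # diffusion.p_losses.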

    while global_step < max_steps:
        pbar = tqdm(iterable=None, disable=not accelerator.is_local_main_process, unit="batch", total=len(loader))
        for gt, lq, prompt in loader:
            gt = rearrange(gt, "b h w c -> b c h w").contiguous().float().to(device)
            lq = rearrange(lq, "b h w c -> b c h w").contiguous().float().to(device)
            with torch.no_grad():
                z_0 = pure_cldm.vae_encode(gt)
                clean = swinir(lq)
                cond = pure_cldm.prepare_condition(clean, prompt)
            t = torch.randint(0, diffusion.num_timesteps, (z_0.shape[0],), device=device)

            loss = diffusion.p_losses(cldm, z_0, t, cond)
            opt.zero_grad()
            accelerator.backward(loss)
            opt.step()

            accelerator.wait_for_everyone()

            global_step += 1
            step_loss.append(loss.item())
            epoch_loss.append(loss.item())
            pbar.update(1)
            pbar.set_description(f"Epoch: {epoch:04d}, Global Step: {global_step:07d}, Loss: {loss.item():.6f}")
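
            # The step above follows the usual conditional latent-diffusion recipe:
            # t is sampled uniformly from [0, num_timesteps), and p_losses is
            # expected to noise z_0 to z_t and regress the network output against
            # the injected noise (the "simple" epsilon-MSE objective).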

            # Log loss values:
            if global_step % cfg.train.log_every == 0 and global_step > 0:
                # Gather values from all processes
                avg_loss = accelerator.gather(torch.tensor(step_loss, device=device).unsqueeze(0)).mean().item()
                step_loss.clear()
                if accelerator.is_local_main_process:
                    writer.add_scalar("loss/loss_simple_step", avg_loss, global_step)

            # Save checkpoint:
            if global_step % cfg.train.ckpt_every == 0 and global_step > 0:
                if accelerator.is_local_main_process:
                    checkpoint = pure_cldm.controlnet.state_dict()
                    ckpt_path = f"{ckpt_dir}/{global_step:07d}.pt"
                    torch.save(checkpoint, ckpt_path)
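
            # Only the ControlNet state dict is saved: it is the only module being
            # optimized, and cfg.train.resume restores exactly this checkpoint via
            # load_controlnet_from_ckpt above.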

            if global_step % cfg.train.image_every == 0 or global_step == 1:
                N = 12
                log_clean = clean[:N]
                log_cond = {k: v[:N] for k, v in cond.items()}
                log_gt, log_lq = gt[:N], lq[:N]
                log_prompt = prompt[:N]
                cldm.eval()
                with torch.no_grad():
                    z = sampler.sample(
                        model=cldm, device=device, steps=50, batch_size=len(log_gt), x_size=z_0.shape[1:],
                        cond=log_cond, uncond=None, cfg_scale=1.0, x_T=None,
                        progress=accelerator.is_local_main_process, progress_leave=False
                    )
                    if accelerator.is_local_main_process:
                        for tag, image in [
                            ("image/samples", (pure_cldm.vae_decode(z) + 1) / 2),
                            ("image/gt", (log_gt + 1) / 2),
                            ("image/lq", log_lq),
                            ("image/condition", log_clean),
                            ("image/condition_decoded", (pure_cldm.vae_decode(log_cond["c_img"]) + 1) / 2),
                            ("image/prompt", (log_txt_as_img((512, 512), log_prompt) + 1) / 2)
                        ]:
                            writer.add_image(tag, make_grid(image, nrow=4), global_step)
                cldm.train()
            accelerator.wait_for_everyone()
            if global_step == max_steps:
                break

        pbar.close()
        epoch += 1
        avg_epoch_loss = accelerator.gather(torch.tensor(epoch_loss, device=device).unsqueeze(0)).mean().item()
        epoch_loss.clear()
        if accelerator.is_local_main_process:
            writer.add_scalar("loss/loss_simple_epoch", avg_epoch_loss, global_step)

    if accelerator.is_local_main_process:
        print("done!")
        writer.close()


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    args = parser.parse_args()
    main(args)
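
# Launch sketch (assumes `accelerate` is configured and the YAML config provides the
# keys read above: train.sd_path, train.swinir_path, train.exp_dir, train.batch_size,
# train.num_workers, train.learning_rate, train.train_steps, train.log_every,
# train.ckpt_every, train.image_every, train.resume, model.*, dataset.train;
# the config path below is hypothetical):
#
#   accelerate launch train_stage2.py --config configs/train_stage2.yaml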