# Semanticist_AR/semanticist/engine/diffusion_trainer.py
# Repository page metadata: user "tennant", commit message "upload", commit 7b0a1ef
import os, torch
import os.path as osp
import shutil
from tqdm.auto import tqdm
from einops import rearrange
from accelerate import Accelerator
from torchvision.utils import make_grid, save_image
from torch.utils.data import DataLoader, random_split, DistributedSampler
from semanticist.utils.logger import SmoothedValue, MetricLogger, empty_cache
from accelerate.utils import DistributedDataParallelKwargs
from torchmetrics.functional.image import (
peak_signal_noise_ratio as psnr,
structural_similarity_index_measure as ssim
)
from semanticist.engine.trainer_utils import (
instantiate_from_config, concat_all_gather,
save_img_batch, get_fid_stats,
EMAModel, PaddedDataset, create_scheduler, load_state_dict,
load_safetensors, setup_result_folders, create_optimizer
)
class DiffusionTrainer:
    """Training / evaluation driver built around a HuggingFace ``Accelerator``.

    The trainer instantiates the model and datasets from config objects,
    builds the optimizer and LR scheduler, optionally maintains an EMA copy
    of the weights, and exposes :meth:`train`, :meth:`save` and
    :meth:`evaluate`.
    """

    def __init__(
        self,
        model,
        dataset,
        test_dataset=None,
        test_only=False,
        num_epoch=400,
        valid_size=32,
        blr=1e-4,
        cosine_lr=True,
        lr_min=0,
        warmup_epochs=100,
        warmup_steps=None,
        warmup_lr_init=0,
        decay_steps=None,
        batch_size=32,
        eval_bs=32,
        test_bs=64,
        num_workers=8,
        pin_memory=False,
        max_grad_norm=None,
        grad_accum_steps=1,
        precision='bf16',
        save_every=10000,
        sample_every=1000,
        fid_every=50000,
        result_folder=None,
        log_dir="./log",
        cfg=3.0,
        test_num_slots=None,
        eval_fid=False,
        fid_stats=None,
        enable_ema=False,
        compile=False,
    ):
        """Build the accelerator, model, data loaders, optimizer and scheduler.

        Args:
            model: Config passed to ``instantiate_from_config``. It must
                expose ``params.num_slots`` and ``params.ckpt_path``;
                ``params.eval_fid`` / ``params.fid_stats`` are read as
                legacy fallbacks.
            dataset: Config of the training dataset (only instantiated when
                ``test_only`` is false). ``valid_size`` samples are split
                off with a fixed seed (42) for validation.
            test_dataset: Optional config of the test dataset used for the
                FID / PSNR / SSIM pass.
            test_only: Skip all training setup and only evaluate.
            num_epoch: Number of training epochs.
            valid_size: Number of samples held out for validation.
            blr: Base learning rate; the actual rate is
                ``blr * effective_batch_size / 256``.
            cosine_lr, lr_min, warmup_lr_init, decay_steps: Forwarded to
                ``create_scheduler``.
            warmup_epochs: When not ``None`` it overrides ``warmup_steps``
                with ``warmup_epochs * len(train_dl)``.
            warmup_steps: Warm-up length in steps; only honoured when
                ``warmup_epochs`` is ``None``.
            batch_size, eval_bs, test_bs: Per-process batch sizes of the
                train / validation / test loaders.
            num_workers, pin_memory: Forwarded to every ``DataLoader``.
            max_grad_norm: Gradient clipping threshold (``None`` disables).
            grad_accum_steps: Gradient accumulation steps.
            precision: ``mixed_precision`` argument of the accelerator.
            save_every, sample_every, fid_every: Step intervals for
                checkpointing, validation sampling and the FID pass.
            result_folder: Passed to ``setup_result_folders``; the training
                config is also written there by :meth:`train`.
            log_dir: ``project_dir`` of the accelerator (tracker output).
            cfg: Value forwarded as ``cfg=`` when sampling (presumably the
                classifier-free guidance scale); ``1.0`` disables the
                extra guided pass.
            test_num_slots: Number of slots used at inference, clamped to
                ``num_slots``; ``None`` means all slots.
            eval_fid: Enable the FID pass (requires ``fid_stats`` and
                ``test_dataset``).
            fid_stats: Passed to ``get_fid_stats``.
            enable_ema: Track an EMA copy of the model during training.
            compile: Wrap ``vae``, ``dit`` and ``encoder2slot`` with
                ``torch.compile``.
        """
        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
        self.accelerator = Accelerator(
            kwargs_handlers=[ddp_kwargs],
            mixed_precision=precision,
            gradient_accumulation_steps=grad_accum_steps,
            log_with="tensorboard",
            project_dir=log_dir,
        )

        self.model = instantiate_from_config(model)
        self.num_slots = model.params.num_slots

        if test_dataset is not None:
            test_dataset = instantiate_from_config(test_dataset)
            self.test_dataset_size = len(test_dataset)
            # Pad so that every process sees the same number of full batches.
            padding_size = self._padding_for_even_shards(
                self.test_dataset_size, self.accelerator.num_processes, test_bs
            )
            self.test_ds = PaddedDataset(test_dataset, padding_size)
            self.test_dl = DataLoader(
                self.test_ds,
                batch_size=test_bs,
                num_workers=num_workers,
                pin_memory=pin_memory,
                shuffle=False,
                drop_last=True,
            )
            if self.accelerator.is_main_process:
                print(f"test dataset size: {len(test_dataset)}, test batch size: {test_bs}")
        else:
            self.test_dl = None

        self.test_only = test_only
        if not test_only:
            dataset = instantiate_from_config(dataset)
            train_size = len(dataset) - valid_size
            self.train_ds, self.valid_ds = random_split(
                dataset,
                [train_size, valid_size],
                generator=torch.Generator().manual_seed(42),
            )
            if self.accelerator.is_main_process:
                print(f"train dataset size: {train_size}, valid dataset size: {valid_size}")

            # The train loader is sharded manually, so it is NOT passed to
            # accelerator.prepare below.
            sampler = DistributedSampler(
                self.train_ds,
                num_replicas=self.accelerator.num_processes,
                rank=self.accelerator.process_index,
                shuffle=True,
            )
            self.train_dl = DataLoader(
                self.train_ds,
                batch_size=batch_size,
                sampler=sampler,
                num_workers=num_workers,
                pin_memory=pin_memory,
                drop_last=True,
            )
            self.valid_dl = DataLoader(
                self.valid_ds,
                batch_size=eval_bs,
                shuffle=False,
                num_workers=num_workers,
                pin_memory=pin_memory,
                drop_last=False,
            )

            effective_bs = batch_size * grad_accum_steps * self.accelerator.num_processes
            lr = blr * effective_bs / 256
            if self.accelerator.is_main_process:
                print(f"Effective batch size is {effective_bs}")

            self.g_optim = create_optimizer(self.model, weight_decay=0.05, learning_rate=lr)
            if warmup_epochs is not None:
                warmup_steps = warmup_epochs * len(self.train_dl)
            self.g_sched = create_scheduler(
                self.g_optim,
                num_epoch,
                len(self.train_dl),
                lr_min,
                warmup_steps,
                warmup_lr_init,
                decay_steps,
                cosine_lr,
            )
            self.accelerator.register_for_checkpointing(self.g_sched)
            self.model, self.g_optim, self.g_sched = self.accelerator.prepare(
                self.model, self.g_optim, self.g_sched
            )
            # The FID pass gathers reconstructions from all processes and
            # strips trailing padding, which only works if the test loader is
            # sharded. Without this, every rank would reconstruct the same
            # images and the gather would collect duplicates.
            if self.test_dl is not None:
                self.test_dl = self.accelerator.prepare(self.test_dl)
        else:
            self.model, self.test_dl = self.accelerator.prepare(self.model, self.test_dl)

        self.steps = 0
        self.loaded_steps = -1  # -1 means "no step-numbered checkpoint loaded"

        if compile:
            _model = self.accelerator.unwrap_model(self.model)
            _model.vae = torch.compile(_model.vae, mode="reduce-overhead")
            _model.dit = torch.compile(_model.dit, mode="reduce-overhead")
            # The encoder is deliberately left uncompiled: compiling it
            # together with dit produced NaN losses (cause unknown).
            _model.encoder2slot = torch.compile(_model.encoder2slot, mode="reduce-overhead")

        self.enable_ema = enable_ema
        # In test-only mode the EMA weights are loaded straight into the model
        # by _load_checkpoint, so no EMA wrapper is needed.
        if self.enable_ema and not self.test_only:
            self.ema_model = EMAModel(self.accelerator.unwrap_model(self.model), self.device)
            self.accelerator.register_for_checkpointing(self.ema_model)

        self._load_checkpoint(model.params.ckpt_path)
        if self.test_only:
            self.steps = self.loaded_steps

        self.num_epoch = num_epoch
        self.save_every = save_every
        self.sample_every = sample_every
        self.fid_every = fid_every
        self.max_grad_norm = max_grad_norm

        self.cfg = cfg
        if test_num_slots is None:
            self.test_num_slots = self.num_slots
        else:
            self.test_num_slots = min(test_num_slots, self.num_slots)

        eval_fid = eval_fid or model.params.eval_fid  # legacy
        self.eval_fid = eval_fid
        if eval_fid:
            if fid_stats is None:
                fid_stats = model.params.fid_stats  # legacy
            assert fid_stats is not None, "eval_fid requires fid_stats"
            assert test_dataset is not None, "eval_fid requires a test_dataset"
        self.fid_stats = fid_stats

        self.result_folder = result_folder
        self.model_saved_dir, self.image_saved_dir = setup_result_folders(result_folder)

    @staticmethod
    def _padding_for_even_shards(total_size, world_size, batch_size):
        """Return how many samples must be appended to ``total_size`` so that
        it becomes a multiple of ``world_size * batch_size``.

        Returns 0 when the size is already a multiple, so no extra all-padding
        global batch is created.
        """
        global_batch = world_size * batch_size
        return (-total_size) % global_batch

    @staticmethod
    def _parse_step_from_ckpt_dir(ckpt_path):
        """Extract the step number from a path like ``path/to/models/step10/``.

        Raises:
            ValueError: If no integer follows the last ``"step"`` in the path.
        """
        tail = ckpt_path.split("step")[-1].split("/")[0]
        try:
            return int(tail)
        except ValueError as err:
            raise ValueError(
                f"Cannot infer the step count from checkpoint directory {ckpt_path!r}; "
                "expected a path component like 'step<N>'"
            ) from err

    @property
    def device(self):
        """Device of the underlying accelerator."""
        return self.accelerator.device

    def _load_checkpoint(self, ckpt_path=None):
        """Restore weights (and, when training, the full accelerator state).

        * ``None`` or a non-existent path: nothing is loaded.
        * Directory: the step count is parsed from the path. When training,
          the full accelerator state is loaded; in test-only mode either the
          EMA weights (``custom_checkpoint_1.pkl``) or ``model.safetensors``
          are loaded.
        * File: loaded as safetensors or as a ``torch.load`` state dict.
        """
        if ckpt_path is None:
            return
        if not osp.exists(ckpt_path):
            if self.accelerator.is_main_process:
                print(f"Checkpoint path {ckpt_path} does not exist, nothing loaded")
            return

        model = self.accelerator.unwrap_model(self.model)
        weights_loaded = True
        if osp.isdir(ckpt_path):
            self.loaded_steps = self._parse_step_from_ckpt_dir(ckpt_path)
            if not self.test_only:
                self.accelerator.load_state(ckpt_path)
            elif self.enable_ema:
                # Index 1: the EMA model is the second object registered for
                # checkpointing at training time (the scheduler is first).
                model_path = osp.join(ckpt_path, "custom_checkpoint_1.pkl")
                weights_loaded = osp.exists(model_path)
                if weights_loaded:
                    # NOTE(review): torch.load unpickles the file; only use
                    # checkpoints from trusted sources.
                    state_dict = torch.load(model_path, map_location="cpu")
                    load_state_dict(state_dict, model)
                    if self.accelerator.is_main_process:
                        print(f"Loaded ema model from {model_path}")
            else:
                model_path = osp.join(ckpt_path, "model.safetensors")
                weights_loaded = osp.exists(model_path)
                if weights_loaded:
                    load_safetensors(model_path, model)
            if not weights_loaded and self.accelerator.is_main_process:
                print(f"WARNING: {model_path} not found, model weights were NOT loaded")
        elif ckpt_path.endswith(".safetensors"):
            load_safetensors(ckpt_path, model)
        else:
            # NOTE(review): torch.load unpickles the file; only use
            # checkpoints from trusted sources.
            state_dict = torch.load(ckpt_path, map_location="cpu")
            load_state_dict(state_dict, model)

        if weights_loaded and self.accelerator.is_main_process:
            print(f"Loaded checkpoint from {ckpt_path}")

    def train(self, config=None):
        """Run the training loop (or a single evaluation in test-only mode).

        Args:
            config: Optional run configuration to archive as
                ``<result_folder>/config.yaml``. Either a path to an existing
                file (copied) or an object accepted by ``OmegaConf.save``.
        """
        n_parameters = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        if self.accelerator.is_main_process:
            print(f"number of learnable parameters: {n_parameters//1e6}M")

        # Only one process writes the config, so ranks do not race on the file.
        if config is not None and self.accelerator.is_main_process:
            from omegaconf import OmegaConf
            config_save_path = osp.join(self.result_folder, "config.yaml")
            if isinstance(config, str) and osp.exists(config):
                shutil.copy(config, config_save_path)
            else:
                OmegaConf.save(config, config_save_path)

        self.accelerator.init_trackers("semanticist")

        if self.test_only:
            empty_cache()
            self.evaluate()
            self.accelerator.wait_for_everyone()
            empty_cache()
            return

        steps_per_epoch = len(self.train_dl)
        for epoch in range(self.num_epoch):
            # Epochs fully covered by the loaded checkpoint are only counted.
            if (epoch + 1) * steps_per_epoch <= self.loaded_steps:
                if self.accelerator.is_main_process:
                    print(f"Epoch {epoch} is skipped because it is loaded from ckpt")
                self.steps += steps_per_epoch
                continue

            # DistributedSampler reshuffles only when told the epoch; without
            # this call every epoch reuses the same sample order.
            sampler = getattr(self.train_dl, "sampler", None)
            if hasattr(sampler, "set_epoch"):
                sampler.set_epoch(epoch)

            # Fast-forward the step counter to the checkpoint position. This
            # equals counting batches one by one, without loading any data.
            # NOTE(review): the epoch below still iterates the full loader.
            if self.steps < self.loaded_steps:
                self.steps = min(self.loaded_steps, self.steps + steps_per_epoch)

            self.accelerator.unwrap_model(self.model).current_epoch = epoch
            self.model.train()

            logger = MetricLogger(delimiter=" ")
            logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
            header = 'Epoch: [{}/{}]'.format(epoch, self.num_epoch)
            print_freq = 20

            for batch in logger.log_every(self.train_dl, print_freq, header):
                img, _ = batch
                img = img.to(self.device, non_blocking=True)
                self.steps += 1

                with self.accelerator.accumulate(self.model):
                    with self.accelerator.autocast():
                        if self.steps == 1:
                            print(f"Training batch size: {img.size(0)}")
                            print(f"Hello from index {self.accelerator.local_process_index}")
                        losses = self.model(img, epoch=epoch)
                        # The total loss is the plain sum of all returned terms.
                        loss = sum(losses.values())

                    self.accelerator.backward(loss)
                    if self.accelerator.sync_gradients and self.max_grad_norm is not None:
                        self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
                    self.g_optim.step()
                    if self.g_sched is not None:
                        self.g_sched.step_update(self.steps)
                    self.g_optim.zero_grad()

                self.accelerator.wait_for_everyone()

                # NOTE(review): the EMA is updated on every micro-batch, also
                # while gradients are still being accumulated.
                if self.enable_ema:
                    self.ema_model.update(self.accelerator.unwrap_model(self.model))

                # Read the scalars once and reuse them for console and tracker.
                loss_values = {key: value.item() for key, value in losses.items()}
                current_lr = self.g_optim.param_groups[0]["lr"]
                logger.update(**loss_values)
                logger.update(lr=current_lr)

                if self.steps % self.save_every == 0:
                    self.save()

                if (self.steps % self.sample_every == 0) or (self.steps % self.fid_every == 0):
                    empty_cache()
                    self.evaluate()
                    self.accelerator.wait_for_everyone()
                    empty_cache()

                # Values are per-process (no all_gather across ranks).
                write_dict = {"epoch": epoch}
                write_dict.update(loss_values)
                write_dict["lr"] = current_lr
                self.accelerator.log(write_dict, step=self.steps)

            logger.synchronize_between_processes()
            if self.accelerator.is_main_process:
                print("Averaged stats:", logger)

        # Save while the distributed state is still alive; end_training may
        # tear it down on recent Accelerate versions.
        self.save()
        self.accelerator.end_training()
        if self.accelerator.is_main_process:
            print("Train finished!")

    def save(self):
        """Write the accelerator state to ``<model_saved_dir>/step<steps>``."""
        self.accelerator.wait_for_everyone()
        self.accelerator.save_state(
            os.path.join(self.model_saved_dir, f"step{self.steps}")
        )

    @torch.no_grad()
    def evaluate(self):
        """Save validation reconstructions and, when due, run the FID pass.

        Validation grids are written whenever the trainer is not in test-only
        mode. The FID / PSNR / SSIM pass runs when ``eval_fid`` is set, a test
        loader exists, and either the trainer is test-only or the current step
        is a multiple of ``fid_every``. The model is put back into train mode
        at the end.
        """
        self.model.eval()

        if not self.test_only:
            self._save_validation_grids()

        fid_due = self.test_only or (self.steps % self.fid_every == 0)
        if self.eval_fid and self.test_dl is not None and fid_due:
            self._evaluate_fid()

        self.model.train()

    def _save_validation_grids(self):
        """Reconstruct every validation batch and save image/reconstruction grids."""
        with tqdm(
            self.valid_dl,
            dynamic_ncols=True,
            disable=not self.accelerator.is_main_process,
        ) as valid_dl:
            for batch_i, batch in enumerate(valid_dl):
                img = batch[0] if isinstance(batch, (tuple, list)) else batch
                # valid_dl is not prepared by the accelerator, so move the
                # batch explicitly, as the train and test paths do.
                img = img.to(self.device, non_blocking=True)

                self._save_recon_grid(
                    img, 1.0,
                    f"step_{self.steps}_slots{self.test_num_slots}_{batch_i}.jpg",
                )
                if self.cfg != 1.0:
                    self._save_recon_grid(
                        img, self.cfg,
                        f"step_{self.steps}_cfg_{self.cfg}_slots{self.test_num_slots}_{batch_i}.jpg",
                    )

    def _save_recon_grid(self, img, cfg_value, filename):
        """Sample reconstructions of ``img`` and save them interleaved with the inputs."""
        with self.accelerator.autocast():
            rec = self.model(img, sample=True, inference_with_n_slots=self.test_num_slots, cfg=cfg_value)
        pairs = torch.stack((img.to(rec.device), rec), dim=0)
        # Interleave so each input is immediately followed by its reconstruction.
        pairs = rearrange(pairs, "r b ... -> (b r) ...")
        pairs = pairs.detach().cpu().float()
        grid = make_grid(pairs, nrow=6, normalize=True, value_range=(0, 1))
        if self.accelerator.is_main_process:
            save_image(grid, os.path.join(self.image_saved_dir, filename))

    def _evaluate_fid(self):
        """Reconstruct the test set, log FID/ISC/PSNR/SSIM, then delete the images."""
        real_dir = "./dataset/imagenet/val256"
        rec_dir = os.path.join(self.image_saved_dir, f"rec_step{self.steps}_slots{self.test_num_slots}")
        os.makedirs(rec_dir, exist_ok=True)
        rec_cfg_dir = None
        if self.cfg != 1.0:
            rec_cfg_dir = os.path.join(
                self.image_saved_dir,
                f"rec_step{self.steps}_cfg_{self.cfg}_slots{self.test_num_slots}",
            )
            os.makedirs(rec_cfg_dir, exist_ok=True)

        # Unguided pass: always during training, and in test-only mode only
        # when no guidance is requested.
        if self.cfg == 1.0 or not self.test_only:
            psnr_val, ssim_val = self._reconstruct_test_set(1.0, rec_dir, 'Testing: w/o CFG')
            self._log_fid_metrics(real_dir, rec_dir, 1.0, psnr_val, ssim_val)

        # Guided pass.
        if self.cfg != 1.0:
            psnr_val, ssim_val = self._reconstruct_test_set(self.cfg, rec_cfg_dir, 'Testing: w/ CFG')
            self._log_fid_metrics(real_dir, rec_cfg_dir, self.cfg, psnr_val, ssim_val)

        # The reconstructions are only needed for the metric computation.
        if self.accelerator.is_main_process:
            shutil.rmtree(rec_dir)
            if self.cfg != 1.0:
                shutil.rmtree(rec_cfg_dir)

    def _reconstruct_test_set(self, cfg_value, save_dir, header):
        """Reconstruct the test set with ``cfg_value`` and save the images.

        Returns:
            ``(psnr, ssim)`` as Python floats, averaged over all gathered
            per-batch values.
        """
        logger = MetricLogger(delimiter=" ")
        print_freq = 5
        psnr_values = []
        ssim_values = []
        total_processed = 0

        for batch_i, batch in enumerate(logger.log_every(self.test_dl, print_freq, header)):
            imgs = batch[0] if isinstance(batch, (tuple, list)) else batch

            # Everything past the real dataset size is padding.
            if total_processed >= self.test_dataset_size:
                break

            imgs = imgs.to(self.device, non_blocking=True)

            with self.accelerator.autocast():
                recs = self.model(imgs, sample=True, inference_with_n_slots=self.test_num_slots, cfg=cfg_value)

            psnr_val = psnr(recs, imgs, data_range=1.0)
            ssim_val = ssim(recs, imgs, data_range=1.0)

            recs = concat_all_gather(recs).detach()
            psnr_val = concat_all_gather(psnr_val.view(1))
            ssim_val = concat_all_gather(ssim_val.view(1))

            # Drop the padded samples at the end of the gathered batch.
            samples_in_batch = min(
                recs.size(0),
                self.test_dataset_size - total_processed,
            )
            recs = recs[:samples_in_batch]
            # NOTE(review): psnr_val / ssim_val hold one value per process (not
            # per sample), so this slice is normally a no-op, and padded
            # samples in the last batch still contribute to those values.
            psnr_val = psnr_val[:samples_in_batch]
            ssim_val = ssim_val[:samples_in_batch]
            psnr_values.append(psnr_val)
            ssim_values.append(ssim_val)

            if self.accelerator.is_main_process:
                rec_paths = [
                    os.path.join(
                        save_dir,
                        f"step_{self.steps}_slots{self.test_num_slots}_{batch_i}_{j}_rec_cfg_{cfg_value}_slots{self.test_num_slots}.png",
                    )
                    for j in range(recs.size(0))
                ]
                save_img_batch(recs.cpu(), rec_paths)

            total_processed += samples_in_batch
            self.accelerator.wait_for_everyone()

        return torch.cat(psnr_values).mean().item(), torch.cat(ssim_values).mean().item()

    def _log_fid_metrics(self, real_dir, rec_dir, cfg_value, psnr_val, ssim_val):
        """Compute FID / inception score on the main process and log all metrics."""
        if not self.accelerator.is_main_process:
            return

        metrics_dict = get_fid_stats(real_dir, rec_dir, self.fid_stats)
        fid = metrics_dict["frechet_inception_distance"]
        inception_score = metrics_dict["inception_score_mean"]

        # NOTE(review): both passes log under the same keys at the same step;
        # the accompanying "cfg" value is the only discriminator.
        self.accelerator.log({
            "fid": fid,
            "isc": inception_score,
            "psnr": psnr_val,
            "ssim": ssim_val,
            "cfg": cfg_value,
        }, step=self.steps)

        print(f"CFG: {cfg_value} "
              f"FID: {fid:.2f}, ISC: {inception_score:.2f}, "
              f"PSNR: {psnr_val:.2f}, SSIM: {ssim_val:.4f}")