File size: 22,472 Bytes
a71d323 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 |
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, Subset
import argparse
import logging
from colorlog import ColoredFormatter
import tqdm
from itertools import chain
import wandb
import random
import numpy as np
from pathlib import Path
from einops import rearrange
from causalvideovae.model import Refiner, EMA, CausalVAEModel
from causalvideovae.utils.utils import RealVideoDataset
from causalvideovae.model.dataset_videobase import VideoDataset
from causalvideovae.model.utils.module_utils import resolve_str_to_obj
from causalvideovae.model.utils.video_utils import tensor_to_video
import time
try:
    # lpips is needed by valid() to score reconstructions.
    import lpips
except ImportError as err:
    # Only a missing/broken package should be reported here; a bare `except:`
    # would also swallow KeyboardInterrupt/SystemExit and hide the real cause.
    # ImportError subclasses Exception, so existing `except Exception` callers
    # still catch it.
    raise ImportError("Need lpips to valid.") from err
def set_random_seed(seed):
    """Seed every RNG the script relies on with the same ``seed``.

    Covers Python's ``random``, NumPy's global generator, the PyTorch CPU
    generator, the current CUDA device and all CUDA devices (in that order).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
def ddp_setup():
    """Join the NCCL process group, then bind this process to its GPU.

    The device index is read from the ``LOCAL_RANK`` environment variable
    (a KeyError is raised if it is not set).
    """
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
def setup_logger(rank):
    """Configure and return the root logger with a colored, rank-tagged format.

    The root logger level is set to INFO. A DEBUG-level stream handler using a
    ``ColoredFormatter`` is attached only if the root logger has no handlers
    yet, so repeated calls do not duplicate output.

    Args:
        rank: value embedded as ``[rank{rank}]`` at the start of every line.

    Returns:
        The root ``logging.Logger``.
    """
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)

    level_colors = {
        "DEBUG": "cyan",
        "INFO": "green",
        "WARNING": "yellow",
        "ERROR": "red",
        "CRITICAL": "bold_red",
    }
    colored = ColoredFormatter(
        f"[rank{rank}] %(log_color)s%(asctime)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        log_colors=level_colors,
        reset=True,
        style="%",
    )

    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    console.setFormatter(colored)

    # Only attach once; a second call keeps the first handler (and rank tag).
    if not root_logger.handlers:
        root_logger.addHandler(console)
    return root_logger
def check_unused_params(model):
    """Return the names of ``model``'s parameters whose ``.grad`` is ``None``.

    Typically called after a backward pass to spot parameters that did not
    take part in the loss.
    """
    return [
        param_name
        for param_name, tensor in model.named_parameters()
        if tensor.grad is None
    ]
def set_requires_grad_optimizer(optimizer, requires_grad):
    """Set ``requires_grad`` on every parameter registered in ``optimizer``."""
    all_params = (
        tensor
        for group in optimizer.param_groups
        for tensor in group["params"]
    )
    for tensor in all_params:
        tensor.requires_grad = requires_grad
def total_params(model):
    """Return the number of trainable parameters of ``model`` in millions.

    Only parameters with ``requires_grad=True`` are counted; the result is
    truncated to an ``int`` (e.g. 1.9M -> 1).
    """
    trainable = 0
    for tensor in model.parameters():
        if tensor.requires_grad:
            trainable += tensor.numel()
    return int(trainable / 1e6)
def get_exp_name(args):
    """Build the run name from the experiment name and key hyper-parameters.

    Format: ``{exp_name}-lr{lr:.2e}-bs{batch_size}-rs{resolution}-sr{sample_rate}-fr{num_frames}``.
    """
    pieces = [
        f"{args.exp_name}",
        f"lr{args.lr:.2e}",
        f"bs{args.batch_size}",
        f"rs{args.resolution}",
        f"sr{args.sample_rate}",
        f"fr{args.num_frames}",
    ]
    return "-".join(pieces)
def set_train(modules):
    """Call ``.train()`` on each module in ``modules``."""
    for mod in modules:
        mod.train()
def set_eval(modules):
    """Call ``.eval()`` on each module in ``modules``."""
    for mod in modules:
        mod.eval()
def set_modules_requires_grad(modules, requires_grad):
    """Apply ``requires_grad_(requires_grad)`` to each module in ``modules``."""
    for mod in modules:
        mod.requires_grad_(requires_grad)
def save_checkpoint(
    epoch,
    batch_idx,
    optimizer_state,
    state_dict,
    scaler_state,
    checkpoint_dir,
    filename="checkpoint.ckpt",
    ema_state_dict=None,
):
    """Serialize a training checkpoint with ``torch.save``.

    Args:
        epoch: epoch index the checkpoint was taken in.
        batch_idx: batch index within that epoch.
        optimizer_state: optimizer state(s) to store under ``optimizer_state``.
        state_dict: model weights to store under ``state_dict``.
        scaler_state: AMP grad-scaler state to store under ``scaler_state``.
        checkpoint_dir: directory (``str`` or ``Path``) to write into.
        filename: file name inside ``checkpoint_dir``.
        ema_state_dict: EMA weights; ``None`` (the default) stores ``{}``.
            A ``None`` default replaces the previous shared mutable ``{}``
            default, which every call would otherwise alias.

    Returns:
        The ``Path`` the checkpoint was written to.
    """
    if ema_state_dict is None:
        ema_state_dict = {}
    filepath = Path(checkpoint_dir) / filename
    torch.save(
        {
            "epoch": epoch,
            "batch_idx": batch_idx,
            "optimizer_state": optimizer_state,
            "state_dict": state_dict,
            "ema_state_dict": ema_state_dict,
            "scaler_state": scaler_state,
        },
        filepath,
    )
    return filepath
def valid(rank, model, vae, val_dataloader, precision, args):
    """Evaluate the refiner on this rank's share of the validation set.

    Each batch's ``batch['video']`` is encoded and decoded by ``vae``. The
    reconstruction is passed through ``model``, and the refined output is
    compared with the original input frame by frame.

    Args:
        rank: CUDA device index for the batch tensors (and the LPIPS net).
        model: refiner, called as ``model(video_recon)``.
        vae: autoencoder exposing ``encode(x).sample()`` and ``decode(z)``.
        val_dataloader: loader yielding dicts with a ``'video'`` tensor.
        precision: dtype passed to ``torch.cuda.amp.autocast``.
        args: namespace; reads ``eval_lpips`` and ``eval_num_video_log``.

    Returns:
        ``(psnr_list, lpips_list, video_log)``:
        - one mean PSNR per batch;
        - one mean LPIPS per batch (empty unless ``args.eval_lpips``);
        - up to ``args.eval_num_video_log`` refined videos converted with
          ``tensor_to_video``. Only rank 0 collects videos.
    """
    lpips_model = None
    if args.eval_lpips:
        # The LPIPS net is evaluation-only (frozen, eval mode), so it is not
        # wrapped in DDP. DDP would add a collective at construction time and,
        # with the default broadcast_buffers=True, a buffer broadcast on every
        # forward. That buys nothing, since every rank builds the same
        # pretrained network.
        lpips_model = lpips.LPIPS(net='alex', spatial=True)
        lpips_model.to(rank)
        lpips_model.requires_grad_(False)
        lpips_model.eval()

    bar = tqdm.tqdm(total=len(val_dataloader), desc="Validation...") if rank == 0 else None
    psnr_list = []
    lpips_list = []
    video_log = []
    num_video_log = args.eval_num_video_log

    with torch.no_grad():
        for batch in val_dataloader:
            inputs = batch['video'].to(rank)
            with torch.cuda.amp.autocast(dtype=precision):
                latents = vae.encode(inputs).sample()
                video_recon = vae.decode(latents)
                refines = model(video_recon)

            # Keep the first few refined videos for logging (rank 0 only).
            if rank == 0:
                for refine in refines:
                    if num_video_log <= 0:
                        break
                    video_log.append(tensor_to_video(refine))
                    num_video_log -= 1

            # Fold time into the batch axis so metrics are computed per frame.
            inputs = rearrange(inputs, "b c t h w -> (b t) c h w").contiguous()
            refines = rearrange(refines, "b c t h w -> (b t) c h w").contiguous()

            # PSNR with a peak value of 1, averaged over frames.
            # NOTE(review): assumes pixel values are scaled so the peak is 1;
            # confirm against the dataset's normalization.
            mse = torch.mean(torch.square(inputs - refines), dim=(1, 2, 3))
            psnr = 20 * torch.log10(1 / torch.sqrt(mse))
            psnr_list.append(psnr.mean().detach().cpu().item())

            if lpips_model is not None:
                lpips_score = lpips_model.forward(inputs, refines).mean().detach().cpu().item()
                lpips_list.append(lpips_score)

            if bar is not None:
                bar.update()

    # Close the progress bar so it does not linger between validation runs.
    if bar is not None:
        bar.close()
    # Release GPU memory held by the allocator cache.
    torch.cuda.empty_cache()
    return psnr_list, lpips_list, video_log
def gather_valid_result(psnr_list, lpips_list, video_log_list, rank, world_size):
    """All-gather per-rank validation results and reduce them.

    Must be called on every rank (it issues ``dist.all_gather_object``).

    Args:
        psnr_list: this rank's per-batch PSNR values.
        lpips_list: this rank's per-batch LPIPS values (may be empty).
        video_log_list: this rank's logged videos.
        rank: unused; kept for API compatibility.
        world_size: number of ranks in the default process group.

    Returns:
        ``(mean_psnr, mean_lpips, videos)``: the means over all values from all
        ranks (``nan`` when there are no values, e.g. LPIPS disabled), and the
        concatenation of every rank's videos in rank order.
    """
    gathered_psnr = [None] * world_size
    gathered_lpips = [None] * world_size
    gathered_videos = [None] * world_size
    dist.all_gather_object(gathered_psnr, psnr_list)
    dist.all_gather_object(gathered_lpips, lpips_list)
    dist.all_gather_object(gathered_videos, video_log_list)

    def _mean_of_all(per_rank_lists):
        # Flatten before averaging. np.array() on lists of unequal length is
        # ragged and fails, and an all-empty 2-D array yields nan plus a
        # "Mean of empty slice" warning. For equal lengths the result is the
        # same as the 2-D mean.
        values = list(chain.from_iterable(per_rank_lists))
        return np.mean(values) if values else np.nan

    return (
        _mean_of_all(gathered_psnr),
        _mean_of_all(gathered_lpips),
        list(chain.from_iterable(gathered_videos)),
    )
def _autocast_dtype(mix_precision):
    """Map a ``--mix_precision`` value to an autocast dtype (bf16 for anything else)."""
    if mix_precision == "fp16":
        return torch.float16
    if mix_precision == "fp32":
        return torch.float32
    return torch.bfloat16


def _load_resume_checkpoint(args, model, disc, scaler, gen_optimizer, disc_optimizer, logger):
    """Restore weights, and optionally training state, from ``args.resume_from_checkpoint``.

    Generator weights are taken from ``ema_state_dict`` when it is non-empty
    and the ``NOT_USE_EMA_MODEL`` environment variable is unset or ``"0"``.
    Otherwise they come from ``state_dict["gen_model"]``, or from
    ``state_dict`` itself when it has no ``gen_model`` key. Discriminator
    weights come from ``state_dict["dics_model"]`` (the key ``train`` writes).

    Returns:
        ``(start_epoch, start_batch_idx)``; both are 0 when
        ``--not_resume_training_process`` is set.

    Raises:
        Exception: if the path is not an existing file.
    """
    ckpt_path = args.resume_from_checkpoint
    if not os.path.isfile(ckpt_path):
        raise Exception(
            f"Make sure `{ckpt_path}` is a ckpt file."
        )
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    state_dict = checkpoint["state_dict"]

    # Environment values are strings. Comparing with the int 0 (as before)
    # meant NOT_USE_EMA_MODEL=0 also disabled EMA loading.
    use_ema = os.environ.get("NOT_USE_EMA_MODEL", "0") == "0"
    ema_sd = checkpoint.get("ema_state_dict") or {}
    if use_ema and len(ema_sd) > 0:
        logger.info("Load from EMA state dict! If you want to load from original state dict, you should set NOT_USE_EMA_MODEL=1.")
        # EMA keys were recorded on the DDP wrapper; drop its "module." prefix.
        ema_sd = {key.replace("module.", ""): value for key, value in ema_sd.items()}
        model.module.load_state_dict(ema_sd, strict=True)
    else:
        # Previously this read `sd["state_dict"]`, but `sd` was only bound in
        # the EMA branch, so this path raised NameError.
        model.module.load_state_dict(state_dict.get("gen_model", state_dict))

    if "dics_model" in state_dict:
        disc.module.load_state_dict(state_dict["dics_model"], strict=False)
    else:
        logger.warning("Checkpoint has no `dics_model` entry; discriminator weights are not restored.")

    start_epoch = 0
    start_batch_idx = 0
    if not args.not_resume_training_process:
        scaler.load_state_dict(checkpoint["scaler_state"])
        gen_optimizer.load_state_dict(checkpoint["optimizer_state"]["gen_optimizer"])
        disc_optimizer.load_state_dict(checkpoint["optimizer_state"]["disc_optimizer"])
        start_epoch = checkpoint["epoch"]
        start_batch_idx = checkpoint.get("batch_idx", 0)
        logger.info(
            f"Checkpoint loaded from {ckpt_path}, starting from epoch {start_epoch} batch {start_batch_idx}"
        )
    else:
        logger.warning(
            f"Checkpoint loaded from {ckpt_path}, starting from epoch {start_epoch} batch {start_batch_idx}. But training process is not resumed."
        )
    return start_epoch, start_batch_idx


def train(args):
    """Train a ``Refiner`` on top of a frozen ``CausalVAEModel`` with DDP.

    Each step encodes and decodes the input video with the VAE (no grad,
    autocast), refines the reconstruction with the generator, and then
    alternates between generator and discriminator updates through the
    ``disc`` loss module (``optimizer_idx`` 0 / 1).

    Validation runs on step 1 and every ``args.eval_steps`` steps. Rank 0
    writes checkpoints every ``args.save_ckpt_step`` steps and logs to W&B.
    Training stops after ``args.epochs`` epochs or ``args.max_steps`` steps,
    whichever comes first.
    """
    # Setup DDP and logger
    ddp_setup()
    rank = int(os.environ["LOCAL_RANK"])
    logger = setup_logger(rank)

    # Experiment directory + W&B (rank 0 only)
    ckpt_dir = Path(args.ckpt_dir) / Path(get_exp_name(args))
    if rank == 0:
        try:
            ckpt_dir.mkdir(exist_ok=False, parents=True)
        except FileExistsError:
            # Only "already exists" is expected. Other errors (e.g. permissions)
            # now surface immediately instead of at the first checkpoint save.
            logger.warning(f"`{ckpt_dir}` exists!")
            time.sleep(5)
        logger.warning("Connecting to WANDB...")
        wandb.init(
            project=os.environ.get("WANDB_PROJECT", "causalvideovae"),
            config=args,
            name=get_exp_name(args),
        )
    dist.barrier()

    # Load generator model
    if args.pretrained_model_name_or_path is not None:
        if rank == 0:
            logger.warning(
                f"You are loading a checkpoint from `{args.pretrained_model_name_or_path}`."
            )
        model = Refiner.from_pretrained(
            args.pretrained_model_name_or_path, ignore_mismatched_sizes=False
        )
    elif args.model_config is not None:
        if rank == 0:
            logger.warning(f"Model will be inited randomly.")
        model = Refiner.from_config(args.model_config)
    else:
        raise Exception(
            "You should set either `--pretrained_model_name_or_path` or `--model_config`"
        )

    # Load discriminator / loss module
    disc_cls = resolve_str_to_obj(args.disc_cls, append=False)
    logger.warning(f"disc_class: {args.disc_cls} perceptual_weight: {args.perceptual_weight} loss_type: {args.loss_type}")
    disc = disc_cls(
        disc_start=args.disc_start,
        disc_weight=args.disc_weight,
        logvar_init=args.logvar_init,
        perceptual_weight=args.perceptual_weight,
        loss_type=args.loss_type,
    )

    # Frozen VAE + DDP-wrapped trainable models
    model = model.to(rank)
    vae = CausalVAEModel.from_pretrained(args.vae_path, ignore_mismatched_sizes=False)
    vae.requires_grad_(False)
    # The VAE is only used for inference; make sure any dropout/normalization
    # layers run in inference mode (from_pretrained may already do this).
    vae.eval()
    vae = vae.to(rank).to(torch.bfloat16)
    model = DDP(
        model, device_ids=[rank], find_unused_parameters=args.find_unused_parameters
    )
    disc = disc.to(rank)
    disc = DDP(
        disc, device_ids=[rank], find_unused_parameters=args.find_unused_parameters
    )

    # Data
    dataset = VideoDataset(
        args.video_path,
        sequence_length=args.num_frames,
        resolution=args.resolution,
        sample_rate=args.sample_rate,
        dynamic_sample=args.dynamic_sample,
    )
    ddp_sampler = DistributedSampler(dataset)
    dataloader = DataLoader(
        dataset, batch_size=args.batch_size, sampler=ddp_sampler, pin_memory=True, num_workers=args.dataset_num_worker
    )
    val_dataset = RealVideoDataset(
        real_video_dir=args.eval_video_path,
        num_frames=args.eval_num_frames,
        sample_rate=args.eval_sample_rate,
        crop_size=args.eval_resolution,
        resolution=args.eval_resolution,
    )
    val_dataset = Subset(val_dataset, indices=range(args.eval_subset_size))
    val_sampler = DistributedSampler(val_dataset)
    val_dataloader = DataLoader(val_dataset, batch_size=args.eval_batch_size, sampler=val_sampler, pin_memory=True)

    # Optimizers
    modules_to_train = list(model.module.get_decoder())
    if not args.freeze_encoder:
        modules_to_train += list(model.module.get_encoder())
    else:
        for module in model.module.get_encoder():
            module.eval()
            module.requires_grad_(False)
        logger.warning("Encoder is freezed!")
    parameters_to_train = [p for module in modules_to_train for p in module.parameters()]
    gen_optimizer = torch.optim.Adam(parameters_to_train, lr=args.lr)
    disc_optimizer = torch.optim.Adam(
        disc.module.discriminator.parameters(), lr=args.lr
    )

    # AMP
    scaler = torch.cuda.amp.GradScaler()
    precision = _autocast_dtype(args.mix_precision)

    # Resume
    start_epoch = 0
    start_batch_idx = 0
    if args.resume_from_checkpoint:
        start_epoch, start_batch_idx = _load_resume_checkpoint(
            args, model, disc, scaler, gen_optimizer, disc_optimizer, logger
        )

    ema = None
    if args.ema:
        logger.warning(f"Start with EMA. EMA decay = {args.ema_decay}.")
        ema = EMA(model, args.ema_decay)
        ema.register()

    logger.info("Prepared!")
    dist.barrier()
    if rank == 0:
        logger.info(f"=== Model Params ===")
        logger.info(f"Generator:\t\t{total_params(model.module)}M")
        logger.info(f"\t- Encoder:\t{total_params(model.module.encoder):d}M")
        logger.info(f"\t- Decoder:\t{total_params(model.module.decoder):d}M")
        logger.info(f"Discriminator:\t{total_params(disc.module):d}M")
        logger.info(f"===========")
        logger.info(f"Precision is set to: {args.mix_precision}!")
        logger.info("Start training!")

    # Step budget: computed on every rank (not just rank 0) so all ranks stop
    # at the same step. Previously --max_steps only set the progress-bar total.
    max_steps = (
        args.epochs * len(dataloader) if args.max_steps is None else args.max_steps
    )
    bar_desc = "Epoch: {current_epoch}, Loss: {loss}"
    bar = None
    if rank == 0:
        bar = tqdm.tqdm(total=max_steps, desc="")
        logger.warning("Training Details: ")
        logger.warning(f" Max steps: {max_steps}")
        logger.warning(f" Dataset Samples: {len(dataloader)}")
        logger.warning(
            f" Total Batch Size: {args.batch_size} * {os.environ['WORLD_SIZE']}"
        )
    dist.barrier()

    num_epochs = args.epochs
    current_step = 1

    def update_bar(epoch, n=1):
        # Advance the rank-0 progress bar by `n` steps.
        if bar is not None:
            bar.desc = bar_desc.format(current_epoch=epoch, loss=f"-")
            bar.update(n)

    def valid_model(name=""):
        # Validate `model` on all ranks; rank 0 logs the gathered results.
        set_eval(modules_to_train)
        psnr_list, lpips_list, video_log = valid(rank, model, vae, val_dataloader, precision, args)
        valid_psnr, valid_lpips, valid_video_log = gather_valid_result(psnr_list, lpips_list, video_log, rank, dist.get_world_size())
        if rank == 0:
            name = "_" + name if name != "" else name
            wandb.log({f"val{name}/recon": wandb.Video(np.array(valid_video_log), fps=10)}, step=current_step)
            wandb.log({f"val{name}/psnr": valid_psnr}, step=current_step)
            wandb.log({f"val{name}/lpips": valid_lpips}, step=current_step)
            logger.info(f"{name} Validation done.")

    for epoch in range(num_epochs):
        if current_step > max_steps:
            break
        if epoch < start_epoch:
            # Every batch of this epoch precedes the resumed checkpoint.
            # Fast-forward the step counter and bar without loading data. The
            # old condition (`epoch <= start_epoch and batch_idx < start_batch_idx`)
            # instead re-trained most of each earlier epoch.
            num_batches = len(dataloader)
            update_bar(epoch, num_batches)
            current_step += num_batches
            continue

        set_train(modules_to_train)
        ddp_sampler.set_epoch(epoch)  # Shuffle data at every epoch
        for batch_idx, batch in enumerate(dataloader):
            if epoch == start_epoch and batch_idx < start_batch_idx:
                # Already trained before the checkpoint.
                update_bar(epoch)
                current_step += 1
                if current_step > max_steps:
                    break
                continue

            inputs = batch["video"].to(rank)
            with torch.no_grad():
                with torch.cuda.amp.autocast(dtype=precision):
                    latents = vae.encode(inputs).sample()
                    video_recon = vae.decode(latents)

            # Odd steps (once the discriminator has started) update the
            # discriminator with the generator frozen; all other steps update
            # the generator.
            step_dis = (
                current_step % 2 == 1
                and current_step >= disc.module.discriminator_iter_start
            )
            step_gen = not step_dis
            set_modules_requires_grad(modules_to_train, step_gen)

            with torch.cuda.amp.autocast(dtype=precision):
                outputs = model(video_recon)

            if step_gen:
                # Generator Step
                with torch.cuda.amp.autocast(dtype=precision):
                    g_loss, g_log = disc(
                        inputs,
                        outputs,
                        optimizer_idx=0,
                        global_step=current_step,
                        last_layer=model.module.get_last_layer(),
                        split="train",
                    )
                gen_optimizer.zero_grad()
                scaler.scale(g_loss).backward()
                scaler.step(gen_optimizer)
                scaler.update()
                if ema is not None:
                    ema.update()
                if rank == 0 and current_step % args.log_steps == 0:
                    wandb.log({"train/generator_loss": g_loss.item()}, step=current_step)
            else:
                # Discriminator Step
                with torch.cuda.amp.autocast(dtype=precision):
                    d_loss, d_log = disc(
                        inputs,
                        outputs,
                        optimizer_idx=1,
                        global_step=current_step,
                        last_layer=None,
                        split="train",
                    )
                disc_optimizer.zero_grad()
                scaler.scale(d_loss).backward()
                scaler.step(disc_optimizer)
                scaler.update()
                if rank == 0 and current_step % args.log_steps == 0:
                    wandb.log({"train/discriminator_loss": d_loss.item()}, step=current_step)

            # Validation
            if current_step % args.eval_steps == 0 or current_step == 1:
                if rank == 0:
                    logger.info("Starting validation...")
                valid_model()
                if ema is not None:
                    ema.apply_shadow()
                    valid_model("ema")
                    ema.restore()
                # valid_model() put the trainable modules in eval mode. Without
                # this, the rest of the epoch would train in eval mode, since
                # set_train() previously ran only at epoch start.
                set_train(modules_to_train)

            # Checkpoint
            if current_step % args.save_ckpt_step == 0 and rank == 0:
                file_path = save_checkpoint(
                    epoch,
                    batch_idx,
                    {
                        "gen_optimizer": gen_optimizer.state_dict(),
                        "disc_optimizer": disc_optimizer.state_dict(),
                    },
                    {
                        "gen_model": model.module.state_dict(),
                        "dics_model": disc.module.state_dict(),
                    },
                    scaler.state_dict(),
                    ckpt_dir,
                    f"checkpoint-{current_step}.ckpt",
                    ema_state_dict=ema.shadow if ema is not None else {},
                )
                logger.info(f"Checkpoint has been saved to `{file_path}`.")

            # Update step
            update_bar(epoch)
            current_step += 1
            if current_step > max_steps:
                break

    if bar is not None:
        bar.close()
    dist.destroy_process_group()
def _str2bool(value):
    """Parse a command-line boolean.

    argparse's ``type=bool`` converts any non-empty string, including
    ``"False"``, to ``True``. This accepts the usual spellings instead.

    Raises:
        argparse.ArgumentTypeError: for unrecognized values.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main():
    """Parse command-line arguments, seed all RNGs and start ``train``."""
    parser = argparse.ArgumentParser(description="Distributed Training")
    # Exp setting
    parser.add_argument(
        "--exp_name", type=str, default="test",
        help="experiment name (prefix of the run/checkpoint directory name)",
    )
    parser.add_argument("--seed", type=int, default=1234, help="seed")
    # Training setting
    parser.add_argument(
        "--epochs", type=int, default=10, help="number of epochs to train"
    )
    parser.add_argument(
        "--max_steps", type=int, default=None,
        help="maximum number of training steps (default: epochs * batches per epoch)",
    )
    parser.add_argument("--save_ckpt_step", type=int, default=1000, help="")
    parser.add_argument("--ckpt_dir", type=str, default="./results/", help="")
    parser.add_argument(
        "--batch_size", type=int, default=1, help="batch size for training"
    )
    parser.add_argument("--lr", type=float, default=1e-5, help="learning rate")
    parser.add_argument("--log_steps", type=int, default=5, help="log steps")
    parser.add_argument("--freeze_encoder", action="store_true", help="")
    # Data
    parser.add_argument("--video_path", type=str, default=None, help="")
    parser.add_argument("--num_frames", type=int, default=17, help="")
    parser.add_argument("--resolution", type=int, default=512, help="")
    parser.add_argument("--sample_rate", type=int, default=1, help="")
    # `type=bool` would parse "--dynamic_sample False" as True.
    parser.add_argument("--dynamic_sample", type=_str2bool, default=False, help="")
    # Generator model
    parser.add_argument("--find_unused_parameters", action="store_true", help="")
    parser.add_argument(
        "--pretrained_model_name_or_path", type=str, default=None, help=""
    )
    parser.add_argument(
        "--vae_path", type=str, default=None, help=""
    )
    parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="")
    parser.add_argument("--not_resume_training_process", action="store_true", help="")
    parser.add_argument("--model_config", type=str, default=None, help="")
    parser.add_argument(
        "--mix_precision",
        type=str,
        default="bf16",
        choices=["fp16", "bf16", "fp32"],
        help="precision for training",
    )
    # Discriminator Model
    parser.add_argument("--load_disc_from_checkpoint", type=str, default=None, help="")
    parser.add_argument(
        "--disc_cls",
        type=str,
        default="causalvideovae.model.losses.LPIPSWithDiscriminator3D",
        help="",
    )
    parser.add_argument("--disc_start", type=int, default=5, help="")
    parser.add_argument("--disc_weight", type=float, default=0.5, help="")
    parser.add_argument("--kl_weight", type=float, default=1e-06, help="")
    parser.add_argument("--perceptual_weight", type=float, default=1.0, help="")
    parser.add_argument("--loss_type", type=str, default="l1", help="")
    parser.add_argument("--logvar_init", type=float, default=0.0, help="")
    # Validation
    parser.add_argument("--eval_steps", type=int, default=1000, help="")
    parser.add_argument("--eval_video_path", type=str, default=None, help="")
    parser.add_argument("--eval_num_frames", type=int, default=17, help="")
    parser.add_argument("--eval_resolution", type=int, default=256, help="")
    parser.add_argument("--eval_sample_rate", type=int, default=1, help="")
    parser.add_argument("--eval_batch_size", type=int, default=8, help="")
    parser.add_argument("--eval_subset_size", type=int, default=50, help="")
    parser.add_argument("--eval_num_video_log", type=int, default=2, help="")
    parser.add_argument("--eval_lpips", action="store_true", help="")
    # Dataset
    parser.add_argument("--dataset_num_worker", type=int, default=16, help="")
    # EMA
    parser.add_argument("--ema", action="store_true", help="")
    parser.add_argument("--ema_decay", type=float, default=0.999, help="")
    args = parser.parse_args()
    set_random_seed(args.seed)
    train(args)
if __name__ == "__main__":
    # Run training only when this file is executed as a script, not on import.
    main()
|