# multimodalart's picture
# Upload 80 files
# 0fd2f06 verified
import gc
import logging
from utils.dataset import ShardingLMDBDataset, cycle
from utils.distributed import EMA_FSDP, fsdp_wrap, fsdp_state_dict, launch_distributed_job
from utils.misc import (
set_seed,
merge_dict_list
)
import torch.distributed as dist
from omegaconf import OmegaConf
from model import GAN
import torch
import wandb
import time
import os
class Trainer:
    """FSDP-based GAN trainer for a generator and a critic (``fake_score``).

    The constructor initialises the distributed job, wraps the generator,
    critic and text encoder with ``fsdp_wrap``, builds the optimizers, the
    dataloader and an optional generator EMA, and optionally loads a
    pretrained generator. ``train`` then alternates critic and generator
    updates in an endless loop.

    Critic parameters are split into two AdamW groups: discriminator-head
    parameters (names containing one of ``DISCRIMINATOR_PARAM_KEYS``) use
    ``critic_lr * discriminator_lr_multiplier``; all others use ``critic_lr``.
    During the first ``discriminator_warmup_steps`` steps only the
    discriminator-head parameters are trained and the generator is not updated.
    """

    # Name substrings marking a fake_score parameter as a discriminator-head parameter.
    DISCRIMINATOR_PARAM_KEYS = ("_cls_pred_branch", "_gan_ca_blocks")

    # Optional resume helpers exposing ``try_best_load`` (and, for the critic, an
    # ``optimizer`` attribute). Nothing in this class creates them, so resuming
    # through them is skipped while they are None (e.g. a subclass may set them).
    checkpointer_critic = None
    checkpointer_generator = None

    def __init__(self, config):
        """Set up distributed state, models, optimizers, data and EMA.

        Args:
            config: OmegaConf config holding all training hyper-parameters.
        """
        self.config = config
        self.step = 0

        # Step 1: Initialize the distributed training environment (rank, seed, dtype, logging etc.)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

        launch_distributed_job()
        global_rank = dist.get_rank()
        self.world_size = dist.get_world_size()

        self.dtype = torch.bfloat16 if config.mixed_precision else torch.float32
        self.device = torch.cuda.current_device()
        self.is_main_process = global_rank == 0
        self.causal = config.causal
        self.disable_wandb = config.disable_wandb

        # Configuration for discriminator warmup
        self.discriminator_warmup_steps = getattr(config, "discriminator_warmup_steps", 0)
        self.in_discriminator_warmup = self.step < self.discriminator_warmup_steps
        if self.in_discriminator_warmup and self.is_main_process:
            print(f"Starting with discriminator warmup for {self.discriminator_warmup_steps} steps")
        self.loss_scale = getattr(config, "loss_scale", 1.0)

        # A seed of 0 means "pick one at random": rank 0 draws it and broadcasts it
        # so every rank starts from the same base seed.
        if config.seed == 0:
            random_seed = torch.randint(0, 10000000, (1,), device=self.device)
            dist.broadcast(random_seed, src=0)
            config.seed = random_seed.item()

        set_seed(config.seed + global_rank)

        if self.is_main_process and not self.disable_wandb:
            wandb.login(host=config.wandb_host, key=config.wandb_key)
            wandb.init(
                config=OmegaConf.to_container(config, resolve=True),
                name=config.config_name,
                mode="online",
                entity=config.wandb_entity,
                project=config.wandb_project,
                dir=config.wandb_save_dir
            )

        self.output_path = config.logdir

        # Step 2: Initialize the model and optimizer
        self.model = GAN(config, device=self.device)

        self.model.generator = fsdp_wrap(
            self.model.generator,
            sharding_strategy=config.sharding_strategy,
            mixed_precision=config.mixed_precision,
            wrap_strategy=config.generator_fsdp_wrap_strategy
        )

        self.model.fake_score = fsdp_wrap(
            self.model.fake_score,
            sharding_strategy=config.sharding_strategy,
            mixed_precision=config.mixed_precision,
            wrap_strategy=config.fake_score_fsdp_wrap_strategy
        )

        self.model.text_encoder = fsdp_wrap(
            self.model.text_encoder,
            sharding_strategy=config.sharding_strategy,
            mixed_precision=config.mixed_precision,
            wrap_strategy=config.text_encoder_fsdp_wrap_strategy,
            cpu_offload=getattr(config, "text_encoder_cpu_offload", False)
        )

        if not config.no_visualize or config.load_raw_video:
            self.model.vae = self.model.vae.to(
                device=self.device, dtype=torch.bfloat16 if config.mixed_precision else torch.float32)

        self.generator_optimizer = torch.optim.AdamW(
            [param for param in self.model.generator.parameters()
             if param.requires_grad],
            lr=config.gen_lr,
            betas=(config.beta1, config.beta2)
        )

        # Split the trainable fake_score parameters into discriminator-head
        # parameters and everything else so they can use different learning rates.
        fake_score_params = []
        discriminator_params = []
        for name, param in self.model.fake_score.named_parameters():
            if not param.requires_grad:
                continue
            if self._is_discriminator_param(name):
                discriminator_params.append(param)
            else:
                fake_score_params.append(param)

        # Template groups; _build_critic_optimizer always hands AdamW copies of these.
        self.critic_param_groups = [
            {'params': fake_score_params, 'lr': config.critic_lr},
            {'params': discriminator_params, 'lr': config.critic_lr * config.discriminator_lr_multiplier}
        ]

        # beta1 is fixed to 0.9 during warmup; train() rebuilds the optimizer with
        # config.beta1_critic when the warmup ends.
        critic_beta1 = 0.9 if self.in_discriminator_warmup else config.beta1_critic
        self.critic_optimizer = self._build_critic_optimizer(critic_beta1)

        # Step 3: Initialize the dataloader
        self.data_path = config.data_path
        dataset = ShardingLMDBDataset(config.data_path, max_pair=int(1e8))
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset, shuffle=True, drop_last=True)
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=config.batch_size,
            sampler=sampler,
            num_workers=8)

        if self.is_main_process:
            print("DATASET SIZE %d" % len(dataset))
        # Consumed with next() in train().
        self.dataloader = cycle(dataloader)

        ##############################################################################################################
        # 6. Set up EMA parameter containers
        rename_param = (
            lambda name: name.replace("_fsdp_wrapped_module.", "")
            .replace("_checkpoint_wrapped_module.", "")
            .replace("_orig_mod.", "")
        )
        # Trainable generator parameters keyed by their name with wrapper prefixes removed.
        self.name_to_trainable_params = {
            rename_param(n): p
            for n, p in self.model.generator.named_parameters()
            if p.requires_grad
        }

        ema_weight = config.ema_weight
        self.generator_ema = None
        if self._ema_enabled(ema_weight):
            print(f"Setting up EMA with weight {ema_weight}")
            self.generator_ema = EMA_FSDP(self.model.generator, decay=ema_weight)

        ##############################################################################################################
        # 7. (If resuming) Load the model and optimizer, lr_scheduler, ema's statedicts
        if getattr(config, "generator_ckpt", False):
            print(f"Loading pretrained generator from {config.generator_ckpt}")
            state_dict = torch.load(config.generator_ckpt, map_location="cpu")
            if "generator" in state_dict:
                state_dict = state_dict["generator"]
            elif "model" in state_dict:
                state_dict = state_dict["model"]
            self.model.generator.load_state_dict(
                state_dict, strict=True
            )

        if hasattr(config, "load"):
            resume_ckpt_path_critic = os.path.join(config.load, "critic")
            resume_ckpt_path_generator = os.path.join(config.load, "generator")
        else:
            resume_ckpt_path_critic = "none"
            resume_ckpt_path_generator = "none"

        # The checkpointers are optional (see the class attributes); without them
        # there is nothing to resume from and training starts at step 0.
        if self.checkpointer_critic is not None:
            self.checkpointer_critic.try_best_load(
                resume_ckpt_path=resume_ckpt_path_critic,
            )
        if self.checkpointer_generator is not None:
            self.step, _ = self.checkpointer_generator.try_best_load(
                resume_ckpt_path=resume_ckpt_path_generator,
                force_start_w_ema=config.force_start_w_ema,
                force_reset_zero_step=config.force_reset_zero_step,
                force_reinit_ema=config.force_reinit_ema,
                skip_optimizer_scheduler=config.skip_optimizer_scheduler,
            )

        ##############################################################################################################

        # Drop the EMA before ema_start_step to save compute; train() recreates it later.
        if self.step < config.ema_start_step:
            self.generator_ema = None

        self.max_grad_norm_generator = getattr(config, "max_grad_norm_generator", 10.0)
        self.max_grad_norm_critic = getattr(config, "max_grad_norm_critic", 10.0)
        self.previous_time = None

    @classmethod
    def _is_discriminator_param(cls, name):
        """Return True if the fake_score parameter ``name`` belongs to the discriminator head."""
        return any(key in name for key in cls.DISCRIMINATOR_PARAM_KEYS)

    @staticmethod
    def _ema_enabled(ema_weight):
        """Return True when ``ema_weight`` requests an EMA (it is not None and is > 0)."""
        return ema_weight is not None and ema_weight > 0.0

    def _build_critic_optimizer(self, beta1):
        """Create a fresh AdamW over ``self.critic_param_groups`` with ``betas=(beta1, beta2_critic)``.

        ``torch.optim.Optimizer.add_param_group`` writes defaults such as ``betas``
        into the group dicts it receives, in place. The template groups are therefore
        copied, so that a later rebuild really uses the new ``beta1``.
        """
        return torch.optim.AdamW(
            [dict(group) for group in self.critic_param_groups],
            betas=(beta1, self.config.beta2_critic)
        )

    def _loss_ratio(self, batch, batch_size):
        """Return the gradient scale ``mini_bs * world_size / full_bs`` for ``batch``.

        When the batch carries no ``mini_bs`` / ``full_bs`` entries, they default to
        ``batch_size`` and ``batch_size * world_size``, which gives a ratio of 1.0.
        """
        mini_bs = batch["mini_bs"] if "mini_bs" in batch else batch_size
        full_bs = batch["full_bs"] if "full_bs" in batch else batch_size * self.world_size
        return mini_bs * self.world_size / full_bs

    def save(self):
        """Gather the FSDP state dicts and write ``checkpoint_model_<step>/model.pt``.

        Every rank calls ``fsdp_state_dict``; only the main process writes the file.
        The EMA weights are included only when an EMA exists and ``ema_start_step < step``.
        """
        print("Start gathering distributed model states...")
        generator_state_dict = fsdp_state_dict(
            self.model.generator)
        critic_state_dict = fsdp_state_dict(
            self.model.fake_score)

        state_dict = {
            "generator": generator_state_dict,
            "critic": critic_state_dict,
        }
        # generator_ema is None when EMA is disabled or has not started yet.
        if self.generator_ema is not None and self.config.ema_start_step < self.step:
            state_dict["generator_ema"] = self.generator_ema.state_dict()

        if self.is_main_process:
            checkpoint_dir = os.path.join(self.output_path, f"checkpoint_model_{self.step:06d}")
            os.makedirs(checkpoint_dir, exist_ok=True)
            checkpoint_file = os.path.join(checkpoint_dir, "model.pt")
            torch.save(state_dict, checkpoint_file)
            print("Model saved to", checkpoint_file)

    def fwdbwd_one_step(self, batch, train_generator):
        """Run one forward/backward pass and clip gradients for the generator or the critic.

        Args:
            batch: mapping with ``"prompts"`` and either ``"ode_latent"`` or
                ``"frames"``; optional ``"mini_bs"`` / ``"full_bs"`` scale the loss.
            train_generator: True for a generator pass, False for a critic pass.

        Returns:
            dict of values to log: the gradient norm and the losses of the pass.
        """
        self.model.eval()  # prevent any randomness (e.g. dropout)

        if self.step % 20 == 0:
            torch.cuda.empty_cache()

        # Step 1: Unpack the prompts and obtain the clean latents
        text_prompts = batch["prompts"]
        if "ode_latent" in batch:
            # Last entry along dim 1 of the stored ODE latents.
            clean_latent = batch["ode_latent"][:, -1].to(device=self.device, dtype=self.dtype)
        else:
            frames = batch["frames"].to(device=self.device, dtype=self.dtype)
            with torch.no_grad():
                clean_latent = self.model.vae.encode_to_latent(
                    frames).to(device=self.device, dtype=self.dtype)

        image_latent = clean_latent[:, 0:1]
        initial_latent = image_latent if self.config.i2v else None

        batch_size = len(text_prompts)
        image_or_video_shape = list(self.config.image_or_video_shape)
        image_or_video_shape[0] = batch_size

        # Step 2: Extract the conditional infos
        with torch.no_grad():
            conditional_dict = self.model.text_encoder(
                text_prompts=text_prompts)

            # The negative-prompt embedding is computed once and cached on self.
            if not getattr(self, "unconditional_dict", None):
                unconditional_dict = self.model.text_encoder(
                    text_prompts=[self.config.negative_prompt] * batch_size)
                unconditional_dict = {k: v.detach()
                                      for k, v in unconditional_dict.items()}
                self.unconditional_dict = unconditional_dict
            else:
                unconditional_dict = self.unconditional_dict

        loss_ratio = self._loss_ratio(batch, batch_size)

        # Step 3: Store gradients for the generator (if training the generator)
        if train_generator:
            gan_G_loss = self.model.generator_loss(
                image_or_video_shape=image_or_video_shape,
                conditional_dict=conditional_dict,
                unconditional_dict=unconditional_dict,
                clean_latent=clean_latent,
                initial_latent=initial_latent
            )
            total_loss = gan_G_loss * loss_ratio * self.loss_scale
            total_loss.backward()
            generator_grad_norm = self.model.generator.clip_grad_norm_(
                self.max_grad_norm_generator)
            return {"generator_grad_norm": generator_grad_norm,
                    "gan_G_loss": gan_G_loss}

        # Step 4: Store gradients for the critic (if training the critic)
        (gan_D_loss, r1_loss, r2_loss), critic_log_dict = self.model.critic_loss(
            image_or_video_shape=image_or_video_shape,
            conditional_dict=conditional_dict,
            unconditional_dict=unconditional_dict,
            clean_latent=clean_latent,
            real_image_or_video=clean_latent,
            initial_latent=initial_latent
        )
        total_loss = (gan_D_loss + 0.5 * (r1_loss + r2_loss)) * loss_ratio * self.loss_scale
        total_loss.backward()
        critic_grad_norm = self.model.fake_score.clip_grad_norm_(
            self.max_grad_norm_critic)

        critic_log_dict.update({"critic_grad_norm": critic_grad_norm,
                                "gan_D_loss": gan_D_loss,
                                "r1_loss": r1_loss,
                                "r2_loss": r2_loss})
        return critic_log_dict

    def generate_video(self, pipeline, prompts, image=None):
        """Sample one video per prompt with ``pipeline`` and return it as a numpy array.

        Args:
            pipeline: object whose ``inference(noise=..., text_prompts=...,
                return_latents=True)`` returns a ``(video, latents)`` pair.
            prompts: list of text prompts.
            image: accepted for interface compatibility; not used.

        Returns:
            ``video.permute(0, 1, 3, 4, 2) * 255.0`` as a float32 CPU numpy array.
            NOTE(review): presumably (batch, frames, H, W, C) with pixel values in
            [0, 255] if the pipeline outputs [0, 1] video; confirm.
        """
        batch_size = len(prompts)
        # NOTE(review): the noise shape [batch, 21, 16, 60, 104] is hard-coded and
        # may need to follow config.image_or_video_shape.
        sampled_noise = torch.randn(
            [batch_size, 21, 16, 60, 104], device="cuda", dtype=self.dtype
        )
        video, _ = pipeline.inference(
            noise=sampled_noise,
            text_prompts=prompts,
            return_latents=True
        )
        # .float() first: numpy cannot represent bfloat16 tensors.
        current_video = video.permute(0, 1, 3, 4, 2).float().cpu().numpy() * 255.0
        return current_video

    @staticmethod
    def _build_wandb_loss_dict(generator_log_dict, critic_log_dict, train_generator):
        """Select the per-step values to average across ranks and log.

        ``generator_grad_norm`` is included only on steps where the generator was
        trained. On other steps ``generator_log_dict`` is empty. Key order is the
        same on every rank, because ``train_generator`` depends only on the step
        and the config.
        """
        loss_dict = {}
        if train_generator:
            loss_dict["generator_grad_norm"] = generator_log_dict["generator_grad_norm"]
        loss_dict["critic_grad_norm"] = critic_log_dict["critic_grad_norm"]
        loss_dict["real_logit"] = critic_log_dict["noisy_real_logit"]
        loss_dict["fake_logit"] = critic_log_dict["noisy_fake_logit"]
        loss_dict["r1_loss"] = critic_log_dict["r1_loss"]
        loss_dict["r2_loss"] = critic_log_dict["r2_loss"]
        return loss_dict

    def train(self):
        """Run the GAN training loop. It never returns.

        Each iteration: rebuild the critic optimizer when the warmup ends, update
        the generator every ``dfake_gen_update_ratio`` steps (never during warmup),
        update the critic, maintain the EMA, save every ``log_iters`` steps, log
        averaged metrics, and periodically run garbage collection.
        """
        start_step = self.step

        while True:
            if self.step == self.discriminator_warmup_steps and self.discriminator_warmup_steps != 0:
                print("Resetting critic optimizer")
                del self.critic_optimizer
                torch.cuda.empty_cache()
                self.critic_optimizer = self._build_critic_optimizer(self.config.beta1_critic)
                if self.checkpointer_critic is not None:
                    self.checkpointer_critic.optimizer = self.critic_optimizer

            # Check if we're in the discriminator warmup phase
            self.in_discriminator_warmup = self.step < self.discriminator_warmup_steps

            # The generator is only updated outside warmup, every dfake_gen_update_ratio steps.
            train_generator = (not self.in_discriminator_warmup
                               and self.step % self.config.dfake_gen_update_ratio == 0)

            if train_generator:
                self.model.fake_score.requires_grad_(False)
                self.model.generator.requires_grad_(True)
                self.generator_optimizer.zero_grad(set_to_none=True)
                batch = next(self.dataloader)
                generator_log_dict = merge_dict_list([self.fwdbwd_one_step(batch, True)])
                self.generator_optimizer.step()
                if self.generator_ema is not None:
                    self.generator_ema.update(self.model.generator)
            else:
                generator_log_dict = {}

            # Train the critic/discriminator
            self.model.generator.requires_grad_(False)
            if self.in_discriminator_warmup:
                # During warmup, only the discriminator-head params get gradients.
                self.model.fake_score.requires_grad_(False)
                for name, param in self.model.fake_score.named_parameters():
                    if self._is_discriminator_param(name):
                        param.requires_grad_(True)
            else:
                self.model.fake_score.requires_grad_(True)

            self.critic_optimizer.zero_grad(set_to_none=True)
            batch = next(self.dataloader)
            critic_log_dict = merge_dict_list([self.fwdbwd_one_step(batch, False)])
            self.critic_optimizer.step()

            # Increment the step since we finished gradient update
            self.step += 1

            if self.is_main_process and self.step == self.discriminator_warmup_steps:
                print(f"Finished discriminator warmup after {self.discriminator_warmup_steps} steps")

            # Create EMA params (if not already created)
            if (self.step >= self.config.ema_start_step
                    and self.generator_ema is None
                    and self._ema_enabled(self.config.ema_weight)):
                self.generator_ema = EMA_FSDP(self.model.generator, decay=self.config.ema_weight)

            # Save the model
            if (not self.config.no_save) and self.step > start_step and self.step % self.config.log_iters == 0:
                torch.cuda.empty_cache()
                self.save()
                torch.cuda.empty_cache()

            # Logging (all ranks take part in the gather)
            wandb_loss_dict = self._build_wandb_loss_dict(
                generator_log_dict, critic_log_dict, train_generator)
            self.all_gather_dict(wandb_loss_dict)
            wandb_loss_dict["diff_logit"] = wandb_loss_dict["real_logit"] - wandb_loss_dict["fake_logit"]
            wandb_loss_dict["reg_loss"] = 0.5 * (wandb_loss_dict["r1_loss"] + wandb_loss_dict["r2_loss"])

            if self.is_main_process:
                if self.in_discriminator_warmup:
                    warmup_status = f"[WARMUP {self.step}/{self.discriminator_warmup_steps}] Training only discriminator params"
                    print(warmup_status)
                    if not self.disable_wandb:
                        wandb_loss_dict.update({"warmup_status": 1.0})

                if not self.disable_wandb:
                    wandb.log(wandb_loss_dict, step=self.step)

            if self.step % self.config.gc_interval == 0:
                if self.is_main_process:
                    logging.info("DistGarbageCollector: Running GC.")
                gc.collect()
                torch.cuda.empty_cache()

            if self.is_main_process:
                current_time = time.time()
                if self.previous_time is not None and not self.disable_wandb:
                    wandb.log({"per iteration time": current_time - self.previous_time}, step=self.step)
                self.previous_time = current_time

    def all_gather_dict(self, target_dict):
        """Replace every value of ``target_dict``, in place, with its mean over all ranks.

        Each value (tensor or Python number) is converted to a detached float32
        tensor on ``self.device``, gathered from every rank with
        ``dist.all_gather_into_tensor`` and averaged, and the mean is stored back as
        a Python float. Each key issues one collective, so every rank must pass the
        same keys in the same order.
        """
        for key in list(target_dict):
            value = torch.as_tensor(target_dict[key], device=self.device).detach().float().contiguous()
            gathered_value = torch.zeros(
                [self.world_size, *value.shape],
                dtype=value.dtype, device=self.device)
            dist.all_gather_into_tensor(gathered_value, value)
            target_dict[key] = gathered_value.mean().item()