import gc
import logging
import os
import time
from collections import defaultdict

import torch
import torch.distributed as dist
import wandb
from omegaconf import OmegaConf

from model import ODERegression
from utils.dataset import ODERegressionLMDBDataset, cycle
from utils.distributed import barrier, fsdp_wrap, fsdp_state_dict, launch_distributed_job
from utils.misc import set_seed


class Trainer:
    def __init__(self, config):
        self.config = config
        self.step = 0

        # Step 1: Initialize the distributed training environment (rank, seed, dtype, logging etc.)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

        launch_distributed_job()
        global_rank = dist.get_rank()
        self.world_size = dist.get_world_size()

        self.dtype = torch.bfloat16 if config.mixed_precision else torch.float32
        self.device = torch.cuda.current_device()
        self.is_main_process = global_rank == 0
        self.disable_wandb = config.disable_wandb

        # use a random seed for the training
        if config.seed == 0:
            random_seed = torch.randint(0, 10000000, (1,), device=self.device)
            dist.broadcast(random_seed, src=0)
            config.seed = random_seed.item()

        set_seed(config.seed + global_rank)

        if self.is_main_process and not self.disable_wandb:
            wandb.login(host=config.wandb_host, key=config.wandb_key)
            wandb.init(
                config=OmegaConf.to_container(config, resolve=True),
                name=config.config_name,
                mode="online",
                entity=config.wandb_entity,
                project=config.wandb_project,
                dir=config.wandb_save_dir
            )

        self.output_path = config.logdir

        # Step 2: Initialize the model and optimizer
        assert config.distribution_loss == "ode", "Only ODE loss is supported for ODE training"
        self.model = ODERegression(config, device=self.device)

        self.model.generator = fsdp_wrap(
            self.model.generator,
            sharding_strategy=config.sharding_strategy,
            mixed_precision=config.mixed_precision,
            wrap_strategy=config.generator_fsdp_wrap_strategy
        )

        self.model.text_encoder = fsdp_wrap(
            self.model.text_encoder,
            sharding_strategy=config.sharding_strategy,
            mixed_precision=config.mixed_precision,
            wrap_strategy=config.text_encoder_fsdp_wrap_strategy,
            cpu_offload=getattr(config, "text_encoder_cpu_offload", False)
        )

        if not config.no_visualize or config.load_raw_video:
            self.model.vae = self.model.vae.to(
                device=self.device,
                dtype=torch.bfloat16 if config.mixed_precision else torch.float32
            )

        self.generator_optimizer = torch.optim.AdamW(
            [param for param in self.model.generator.parameters() if param.requires_grad],
            lr=config.lr,
            betas=(config.beta1, config.beta2),
            weight_decay=config.weight_decay
        )

        # Step 3: Initialize the dataloader
        dataset = ODERegressionLMDBDataset(
            config.data_path, max_pair=getattr(config, "max_pair", int(1e8)))
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset, shuffle=True, drop_last=True)
        dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=config.batch_size, sampler=sampler, num_workers=8)

        total_batch_size = getattr(config, "total_batch_size", None)
        if total_batch_size is not None:
            assert total_batch_size == config.batch_size * self.world_size, \
                "Gradient accumulation is not supported for ODE training"

        self.dataloader = cycle(dataloader)
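
        # For reference: `cycle` (imported from utils.dataset) is assumed to be a simple
        # infinite wrapper around the dataloader, roughly equivalent to the sketch below.
        # This is an assumption for readability only; the actual helper lives in
        # utils/dataset.py and may differ (e.g. it could also re-shuffle per epoch):
        #
        #     def cycle(dataloader):
        #         while True:
        #             for batch in dataloader:
        #                 yield batch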

        ##############################################################################################################
        # Step 4: (If resuming) Load the pretrained generator's state dict.
        # Note: only the generator is restored here; optimizer, lr_scheduler, and EMA
        # states are not loaded by this trainer.
        if getattr(config, "generator_ckpt", False):
            print(f"Loading pretrained generator from {config.generator_ckpt}")
            state_dict = torch.load(config.generator_ckpt, map_location="cpu")["generator"]
            self.model.generator.load_state_dict(state_dict, strict=True)
        ##############################################################################################################

        self.max_grad_norm = 10.0
        self.previous_time = None

    def save(self):
        print("Start gathering distributed model states...")
        generator_state_dict = fsdp_state_dict(self.model.generator)
        state_dict = {"generator": generator_state_dict}

        if self.is_main_process:
            checkpoint_dir = os.path.join(self.output_path, f"checkpoint_model_{self.step:06d}")
            os.makedirs(checkpoint_dir, exist_ok=True)
            checkpoint_path = os.path.join(checkpoint_dir, "model.pt")
            torch.save(state_dict, checkpoint_path)
            print("Model saved to", checkpoint_path)

    def train_one_step(self):
        VISUALIZE = self.step % 100 == 0
        self.model.eval()  # prevent any randomness (e.g. dropout)

        # Step 1: Get the next batch of text prompts and ODE latents
        batch = next(self.dataloader)
        text_prompts = batch["prompts"]
        ode_latent = batch["ode_latent"].to(
            device=self.device, dtype=self.dtype)
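        # NOTE (assumption): `ode_latent` is expected to hold the latents of a pre-computed
        # ODE-solver trajectory, shaped roughly [batch, num_ode_steps, ...], with the final
        # entry along dim 1 (`ode_latent[:, -1]`) being the clean target latent used as
        # ground truth below. The exact layout is defined by ODERegressionLMDBDataset and
        # ODERegression.generator_loss, which are not shown in this file.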

        # Step 2: Extract the conditional infos (text embeddings); no gradients needed
        with torch.no_grad():
            conditional_dict = self.model.text_encoder(
                text_prompts=text_prompts)

        # Step 3: Train the generator
        generator_loss, log_dict = self.model.generator_loss(
            ode_latent=ode_latent,
            conditional_dict=conditional_dict
        )

        unnormalized_loss = log_dict["unnormalized_loss"]
        timestep = log_dict["timestep"]

        # Gather per-sample losses and timesteps from all ranks so the per-timestep
        # statistics below reflect the global batch.
        if self.world_size > 1:
            gathered_unnormalized_loss = torch.zeros(
                [self.world_size, *unnormalized_loss.shape],
                dtype=unnormalized_loss.dtype, device=self.device)
            gathered_timestep = torch.zeros(
                [self.world_size, *timestep.shape],
                dtype=timestep.dtype, device=self.device)
            dist.all_gather_into_tensor(
                gathered_unnormalized_loss, unnormalized_loss)
            dist.all_gather_into_tensor(gathered_timestep, timestep)
        else:
            gathered_unnormalized_loss = unnormalized_loss
            gathered_timestep = timestep

        # Bucket the per-sample losses into 250-timestep bins for logging
        loss_breakdown = defaultdict(list)
        stats = {}

        for t, loss_value in zip(gathered_timestep.flatten(),
                                 gathered_unnormalized_loss.flatten()):
            loss_breakdown[str(int(t.item()) // 250 * 250)].append(loss_value.item())

        for key_t in loss_breakdown.keys():
            stats["loss_at_time_" + key_t] = (
                sum(loss_breakdown[key_t]) / len(loss_breakdown[key_t]))

        self.generator_optimizer.zero_grad()
        generator_loss.backward()
        generator_grad_norm = self.model.generator.clip_grad_norm_(
            self.max_grad_norm)
        self.generator_optimizer.step()

        # Step 4: Visualization
        if VISUALIZE and not self.config.no_visualize and not self.config.disable_wandb and self.is_main_process:
            # Decode the input, output, and ground-truth latents to pixel space
            input_latent = log_dict["input"]
            output_latent = log_dict["output"]
            ground_truth_latent = ode_latent[:, -1]

            input_video = self.model.vae.decode_to_pixel(input_latent)
            output_video = self.model.vae.decode_to_pixel(output_latent)
            ground_truth_video = self.model.vae.decode_to_pixel(ground_truth_latent)

            # Map pixel values from [-1, 1] to [0, 255]
            input_video = 255.0 * (input_video.cpu().numpy() * 0.5 + 0.5)
            output_video = 255.0 * (output_video.cpu().numpy() * 0.5 + 0.5)
            ground_truth_video = 255.0 * (ground_truth_video.cpu().numpy() * 0.5 + 0.5)

            # Log the input, output, and ground-truth videos to wandb
            wandb.log({
                "input": wandb.Video(input_video, caption="Input", fps=16, format="mp4"),
                "output": wandb.Video(output_video, caption="Output", fps=16, format="mp4"),
                "ground_truth": wandb.Video(ground_truth_video, caption="Ground Truth", fps=16, format="mp4"),
            }, step=self.step)

        # Step 5: Logging
        if self.is_main_process and not self.disable_wandb:
            wandb_loss_dict = {
                "generator_loss": generator_loss.item(),
                "generator_grad_norm": generator_grad_norm.item(),
                **stats
            }
            wandb.log(wandb_loss_dict, step=self.step)

        # Step 6: Periodically run garbage collection on all ranks
        if self.step % self.config.gc_interval == 0:
            if self.is_main_process:
                logging.info("DistGarbageCollector: Running GC.")
            gc.collect()

    def train(self):
        while True:
            self.train_one_step()

            if (not self.config.no_save) and self.step % self.config.log_iters == 0:
                self.save()
                torch.cuda.empty_cache()

            barrier()

            # Track and log the wall-clock time per iteration on the main process
            if self.is_main_process:
                current_time = time.time()
                if self.previous_time is None:
                    self.previous_time = current_time
                else:
                    if not self.disable_wandb:
                        wandb.log({"per iteration time": current_time - self.previous_time},
                                  step=self.step)
                    self.previous_time = current_time

            self.step += 1
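

# ---------------------------------------------------------------------------
# Minimal launch sketch (assumption): this module does not define an entry
# point, and the repository's actual launch script is not shown here. The
# snippet below only illustrates how the Trainer could be driven from an
# OmegaConf YAML file under torchrun; the "--config" flag and the config path
# are hypothetical.
#
#   torchrun --nproc_per_node=8 <this_file>.py --config <path/to/config>.yaml
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True,
                        help="Path to an OmegaConf YAML config (hypothetical flag)")
    args = parser.parse_args()

    trainer = Trainer(OmegaConf.load(args.config))
    trainer.train()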