import argparse
import contextlib
import logging
import math
import os
import time
import traceback
from collections import defaultdict
from datetime import datetime

import matplotlib
import mup
import numpy as np
import torch
import torchvision.transforms.functional as transforms_f
import transformers
from accelerate import Accelerator
from accelerate import DistributedDataParallelKwargs
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from einops import rearrange
from lpips import lpips
from matplotlib import pyplot as plt
from skimage import metrics as image_metrics
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import (
    default_data_collator,
    get_scheduler,
)

from data import RawTokenDataset, get_maskgit_collator
from common.eval_utils import decode_tokens, compute_lpips
from genie.st_mask_git import STMaskGIT
from genie.config import GenieConfig
from visualize import decode_latents_wrapper

# NOTE(review): anomaly detection adds significant autograd overhead to every backward pass;
# consider enabling it only while debugging.
torch.autograd.set_detect_anomaly(True)

# Timestamp (taken at import time) used as a prefix of the wandb run name.
now = datetime.now()
formatted_date = now.strftime("%Y-%m-%d %H:%M:%S")

torch.set_float32_matmul_precision("medium")
logger = get_logger(__name__)


def parse_args():
    """
    Build the command-line argument parser for training.

    Returns:
        argparse.ArgumentParser: the (unparsed) parser; callers invoke `.parse_args()` themselves.
    """
    parser = argparse.ArgumentParser(description="Train a spatial-temporal MaskGIT-style model on video generation.")

    # Data
    parser.add_argument(
        "--train_data_dir",
        type=str,
        default="data/1x_humanoid_magvit_traj1000_train",
        help="Directory containing tokenized data, should have a `video.bin`, `metadata.json` and `segment_ids.json`.",
    )
    parser.add_argument(
        "--val_data_dir",
        type=str,
        default="data/1x_humanoid_magvit_traj1000_val",
        help="Directory containing tokenized data, should have a `video.bin`, `metadata.json` and `segment_ids.json`.",
    )
    parser.add_argument(
        "--domain",
        type=str,
        default="1x_humanoid",
        help="The domain name for the dataset",
    )
    parser.add_argument(
        "--window_size",
        type=int,
        default=12,
        help="Number of frames to in a sequence.",
    )
    parser.add_argument(
        "--stride",
        type=int,
        default=None,
        help="Difference in frame count between consecutive frames in a sequence.",
    )
    # NOTE: `store_true` combined with `default=True` means this flag is always True (cannot be disabled via CLI).
    parser.add_argument(
        "--filter_overlaps",
        action="store_true",
        help=(
            "Whether to filter repeated frames in the train dataset (`filter_overlaps` always true for the val set). "
            "Filtering essentially makes the training dataset less correlated but ~16x smaller, "
            "see the `filter_overlaps` argument in `RawTokenDataset` for details."),
        default=True,
    )

    # Model
    parser.add_argument(
        "--llama_config",
        type=str,
        help="`transformers.LlamaConfig` json. "
             "E.g. https://huggingface.co/1x-technologies/Llama_1B_v0/blob/main/config.json",
    )
    parser.add_argument(
        "--diffusion",
        action="store_true",
        help="use diffusion model.",
    )
    parser.add_argument(
        "--genie_config",
        type=str,
        help="GenieConfig json.",
    )
    parser.add_argument(
        "--warmstart_path",
        type=str,
        default=None,
        help="A path to a checkpoint to warmstart a model from, possibly not trained on the same dataset, "
             "will resize embeddings if needed.",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="If the training should continue from a checkpoint folder.",
    )

    # Training
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=4,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--per_device_eval_batch_size",
        type=int,
        default=1,
        help="Batch size (per device) for the evaluation dataloader.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument("--weight_decay", type=float, default=0.05, help="Weight decay to use.")
    parser.add_argument("--num_train_epochs", type=int, default=2, help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--max_eval_steps",
        type=int,
        default=int(1e10),
        help="Only evaluate on `max_eval_steps` batches of validation data per process, faster.",
    )
    parser.add_argument(
        "--eval_every_n_steps",
        type=int,
        default=1000,
        help="Eval every N training steps.",
    )
    parser.add_argument(
        "--vis_every_n_steps",
        type=int,
        default=20000,
        help="Visualize every N training steps.",
    )
    parser.add_argument(
        "--lr_scheduler_type",
        type=str,
        default="constant_with_warmup",
        help="The scheduler type to use.",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup",
                 "custom_cosine"],
    )
    parser.add_argument(
        "--num_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--max_grad_norm",
        type=float,
        default=1.0,
        help="Threshold to clip gradients.",
    )
    parser.add_argument(
        "--attention_dropout",
        type=float,
        default=0.05,
        help="Attention dropout prob.",
    )
    parser.add_argument(
        "--adam_beta_1",
        type=float,
        default=0.9,
    )
    parser.add_argument(
        "--adam_beta_2",
        type=float,
        default=0.95,
    )
    parser.add_argument(
        "--adam_eps",
        type=float,
        default=1e-8,
    )

    # Misc
    parser.add_argument("--output_dir", type=str, required=True, help="Where to store the model checkpoints.")
    parser.add_argument(
        "--checkpointing_steps",
        type=str,
        default="10000",
        help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
    )
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
    parser.add_argument(
        "--overfit_first_batch",
        action="store_true",
        help=(
            "Debug option that trains and validates on only the first batch of the training dataset."
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="wandb",
        help="The integration to report the results and logs to.",
    )
    # NOTE: `store_true` combined with `default=True` means this flag is always True (cannot be disabled via CLI).
    parser.add_argument(
        "--mu_transfer",
        action="store_true",
        help="If specified, will train with mu transfer reparametrizations. Only supports Llama models.",
        default=True,
    )
    # NOTE: `store_true` combined with `default=True` means the model is never compiled from the CLI.
    parser.add_argument(
        "--no_compile",
        action="store_true",
        help="If specified, will not compile the model.",
        default=True,
    )
    parser.add_argument(
        "--run_name",
        type=str,
        default="video_prediction",
        help="",
    )
    parser.add_argument(
        "--cleanup_checkpoints",
        action="store_true",
        help=(
            "Whether to clean up checkpoints (to keep only the last 3) along the training. "),
    )
    # NOTE: `store_true` combined with `default=True` means this flag is always True (cannot be disabled via CLI).
    parser.add_argument(
        "--save_second_epoch",
        action="store_true",
        help="Whether to checkpoint at the end of the second epoch (1-indexing). This one will not be auto-deleted by cleanup.",
        default=True,
    )
    return parser


def save_checkpoint(model, accelerator, args, filename):
    """
    Save the model weights and the Accelerator training state to `os.path.join(args.output_dir, filename)`.

    `save_pretrained` is only invoked on the main process; `accelerator.save_state` is invoked on every
    process.

    Args:
        model: the (possibly wrapped) model; it is unwrapped before `save_pretrained`.
        accelerator: the `Accelerator` managing training state.
        args: parsed CLI args (uses `output_dir`).
        filename: name of the checkpoint directory inside `args.output_dir`.
    """
    unwrapped_model = accelerator.unwrap_model(model)
    save_path = os.path.join(args.output_dir, filename)
    if accelerator.is_main_process:
        unwrapped_model.save_pretrained(
            save_path, is_main_process=accelerator.is_main_process, save_function=accelerator.save
        )

    accelerator.save_state(save_path)


@torch.no_grad()
def visualize(accelerator, model, dataloader, window_size, metrics_prefix="train", max_steps=1):
    """
    Visualize the model's autoregressive generation outputs and log them (plus AR metrics) to wandb.

    For up to `max_steps` batches, the first `config.num_prompt_frames` ground-truth frames of at most
    `TEST_NUM` examples are used as the prompt and the remaining `window_size - num_prompt_frames` frames
    are generated. Predictions and ground truth are decoded and gathered across processes; the main
    process logs context / ground-truth / prediction figures plus LPIPS, PSNR and SSIM. When the batch
    contains `action_ids`, a second generation with random actions is run and the PSNR drop
    (`ar_psnr_delta`) is logged as a controllability measure, as in Genie.

    Returns immediately when `config.jointly_predict_states` is false. Any exception inside a batch is
    printed and that batch is skipped (best effort).

    Args:
        accelerator: the `Accelerator` wrapping `model`.
        model: the (possibly wrapped) model; its unwrapped form provides `generate` and `config`.
        dataloader: yields batches with `labels`, `domain`, `h`, `w` and optionally `action_ids`.
        window_size: number of frames per sequence.
        metrics_prefix: prefix used for the logged figure / metric names (e.g. "train" or "val").
        max_steps: maximum number of batches to visualize.
    """
    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    if not unwrapped_model.config.jointly_predict_states:
        return

    metrics = defaultdict(list)
    if accelerator.is_main_process:
        lpips_alex = lpips.LPIPS(net="alex")  # Calculate LPIPS w/ AlexNet, the fastest option
    decode_latents = decode_latents_wrapper()  # re-initializing every time to save memory
    unwrapped_model.eval()

    dataloader_iter = iter(dataloader)
    for step in range(len(dataloader)):
        try:
            batch = next(dataloader_iter)
            # Note: hardcoding 4 image cap for faster inference on small models
            TEST_NUM = 4
            h, w = batch["h"][0], batch["w"][0]
            reshaped_labels = rearrange(batch["labels"][:TEST_NUM], "b (t s) -> b t s",
                                        t=window_size).to(accelerator.device)  # `s` is really `(h, w)`
            domains = batch["domain"][:TEST_NUM]
            action_ids = batch["action_ids"][:TEST_NUM].to(accelerator.device) if "action_ids" in batch else None

            # The first `num_prompt_frames` ground-truth frames are the context; the rest are generated.
            num_prompt_frames = unwrapped_model.config.num_prompt_frames
            num_new_tokens = w * h * (window_size - num_prompt_frames)
            prompt_input_ids = rearrange(reshaped_labels[:, :num_prompt_frames], "b t s -> b (t s)")
            outputs = unwrapped_model.generate(input_ids=prompt_input_ids,
                                               attention_mask=torch.ones_like(prompt_input_ids),
                                               max_new_tokens=num_new_tokens, min_new_tokens=num_new_tokens,
                                               action_ids=action_ids, domain=domains,
                                               w=batch["w"][:TEST_NUM], h=batch["h"][:TEST_NUM])
            output_tokens = rearrange(outputs, "b (t h w) -> b t h w", t=window_size, h=h, w=w)
            gtruth_tokens = rearrange(reshaped_labels[:, num_prompt_frames:], "b t (h w) -> b t h w", h=h, w=w)

            decoded_output = decode_tokens(output_tokens.cpu(), decode_latents)
            decoded_gtruth = decode_tokens(gtruth_tokens.cpu(), decode_latents)
            decoded_output = accelerator.gather(decoded_output.to(accelerator.device)).cpu()
            decoded_gtruth = accelerator.gather(decoded_gtruth.to(accelerator.device)).cpu()

            # As in Genie, controllability is psnr_delta = PSNR(x_t, x_t_hat) - PSNR(x_t, x_t_hatprime), where
            # x_t_hatprime is generated from random actions. The random actions are standard-normal samples
            # (`randn_like`) passed with `skip_normalization=True`.
            if action_ids is not None:
                random_action_ids = torch.randn_like(action_ids)
                random_action_outputs = unwrapped_model.generate(input_ids=prompt_input_ids,
                                                                 attention_mask=torch.ones_like(prompt_input_ids),
                                                                 max_new_tokens=num_new_tokens,
                                                                 min_new_tokens=num_new_tokens,
                                                                 action_ids=random_action_ids, domain=domains,
                                                                 w=batch["w"][:TEST_NUM], h=batch["h"][:TEST_NUM],
                                                                 skip_normalization=True)
                random_output_tokens = rearrange(random_action_outputs, "b (t h w) -> b t h w",
                                                 t=window_size, h=h, w=w)
                random_decoded = decode_tokens(random_output_tokens.cpu(), decode_latents)
                random_decoded = accelerator.gather(random_decoded.to(accelerator.device)).cpu()
                random_pred_frames_numpy = random_decoded[:, num_prompt_frames:].detach().cpu().numpy()

            if accelerator.is_main_process:
                exs_per_fig = 4
                for j in range(0, len(decoded_output), exs_per_fig):
                    fig, axs = plt.subplots(nrows=2 * exs_per_fig, ncols=window_size,
                                            figsize=(3 * window_size, 3 * 2 * exs_per_fig))
                    # If len(decoded_output) is not a multiple of 4, make sure to truncate properly
                    for k in range(min(exs_per_fig, len(decoded_output) - j)):
                        for i in range(num_prompt_frames):
                            for ax in (axs[k * 2, i], axs[k * 2 + 1, i]):
                                ax.imshow(transforms_f.to_pil_image(decoded_output[j + k, i]))
                                ax.set_title("Context")
                                ax.axis("off")

                        for i in range(num_prompt_frames, window_size):
                            axs[k * 2, i].imshow(
                                transforms_f.to_pil_image(decoded_gtruth[j + k, i - num_prompt_frames]))
                            axs[k * 2, i].set_title("Ground truth")
                            axs[k * 2 + 1, i].imshow(transforms_f.to_pil_image(decoded_output[j + k, i]))
                            axs[k * 2 + 1, i].set_title("Prediction")
                            for ax in axs[:, i]:
                                ax.axis("off")

                    wandb_tracker = accelerator.get_tracker("wandb")
                    wandb_tracker.log({f"{domains[0]}/vis_{metrics_prefix}_{j}": fig}, commit=False)
                    plt.close(fig)

                pred_frames = decoded_output[:, num_prompt_frames:]
                # LPIPS is computed once (Note: not parallelizing right now) and recorded under both keys below.
                ar_lpips = list(compute_lpips(decoded_gtruth, pred_frames, lpips_alex))
                metrics["ar_lpips"].extend(ar_lpips)

                gt_frames_numpy = decoded_gtruth.detach().cpu().numpy()
                pred_frames_numpy = pred_frames.detach().cpu().numpy()
                num_examples, num_frames = gt_frames_numpy.shape[0], gt_frames_numpy.shape[1]
                # One PSNR per example, over all of its predicted frames.
                psnr = [image_metrics.peak_signal_noise_ratio(
                    gt_frames_numpy[i] / 255., pred_frames_numpy[i] / 255., data_range=1.0)
                    for i in range(num_examples)]
                # One SSIM per predicted frame index, averaged over examples (frames are channel-first).
                ssim = [np.mean([image_metrics.structural_similarity(
                    gt_frames_numpy[i][j] / 255., pred_frames_numpy[i][j] / 255., data_range=1.0, channel_axis=0)
                    for i in range(num_examples)]) for j in range(num_frames)]

                metrics[f"{metrics_prefix}/ar_psnr"].extend(psnr)
                metrics[f"{metrics_prefix}/ar_ssim"].extend(ssim)
                metrics[f"{batch['domain'][0]}/ar_lpips"].extend(ar_lpips)

                if action_ids is not None:
                    # log controllability as random subtracts groundtruth
                    psnr_delta = [psnr[i] - image_metrics.peak_signal_noise_ratio(
                        gt_frames_numpy[i] / 255., random_pred_frames_numpy[i] / 255., data_range=1.0)
                        for i in range(num_examples)]
                    metrics[f"{metrics_prefix}/ar_psnr_delta"].extend(psnr_delta)
        except Exception:
            print("batch failed", traceback.format_exc())

        if step + 1 >= max_steps:
            break

    unwrapped_model.train()
    if accelerator.is_main_process:
        metrics = {f"{metrics_prefix}_{key}": np.mean(val) for key, val in metrics.items() if len(val) > 0}
        print(f"{metrics=}")
        wandb_tracker = accelerator.get_tracker("wandb")
        wandb_tracker.log(metrics, commit=False)


def train(accelerator, model, optimizer, lr_scheduler, train_dataloader, eval_dataloader, experiment_config, config,
          args):
    """
    Run the training loop.

    Features: optional resume from a checkpoint, manual gradient accumulation (DDP sync only on the last
    micro-batch of each accumulation window), periodic evaluation / visualization / step checkpoints,
    optional end-of-epoch checkpoints, and a final checkpoint named `final_checkpt`.

    Args:
        accelerator: the `Accelerator` all objects were prepared with.
        model, optimizer, lr_scheduler, train_dataloader, eval_dataloader: prepared training objects.
        experiment_config: dict of run metadata; reads `FLOPs_per_update_step` and `effective_batch_size`.
        config: model config; reads `action_loss_weight` when the model returns an `action_loss`.
        args: parsed CLI args.
    """
    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataloader)}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    completed_steps = 0
    starting_epoch = 0
    resume_step = None
    checkpoint_path = ""

    # Step checkpointing is only enabled for a positive integer `checkpointing_steps`
    # (`int("epoch")` used to raise inside the loop, and 0 would divide by zero).
    checkpointing_steps_str = str(args.checkpointing_steps)
    checkpoint_every = int(checkpointing_steps_str) if checkpointing_steps_str.isdigit() else None

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        try:
            if os.path.exists(args.resume_from_checkpoint + "/pytorch_model.bin"):
                checkpoint_path = args.resume_from_checkpoint
                path = os.path.basename(args.resume_from_checkpoint.rstrip("/"))
            else:
                # Fall back to the most recently created directory next to `resume_from_checkpoint`.
                base_path = os.path.dirname(args.resume_from_checkpoint)
                dirs = [os.path.join(base_path, f.name) for f in os.scandir(base_path) if f.is_dir()]
                dirs.sort(key=os.path.getctime)  # sorted by ctime, most recent checkpoint is the last
                if len(dirs) > 0:
                    checkpoint_path = dirs[-1]
                    path = os.path.basename(checkpoint_path)

            accelerator.print(f"Resumed from checkpoint: {checkpoint_path}")
            if os.path.exists(checkpoint_path):
                # for finetuning with a different structures
                print(f"loading checkpoint from {checkpoint_path}")
                # tied weights not saved so can't load strict, but also no need to tie again
                accelerator.load_state(checkpoint_path, strict=False)
                # Extract `epoch_{i}` or `step_{i}`
                training_difference = os.path.splitext(path)[0]
            else:
                print("No checkpoint found, training from scratch.")
                training_difference = "step_0"

            if "epoch" in training_difference:
                starting_epoch = int(training_difference.replace("epoch_", "")) + 1
                resume_step = None
                num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
                completed_steps = starting_epoch * num_update_steps_per_epoch
            else:
                # need to multiply `gradient_accumulation_steps` to reflect real steps
                resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
                starting_epoch = resume_step // len(train_dataloader)
                completed_steps = resume_step // args.gradient_accumulation_steps
                resume_step -= starting_epoch * len(train_dataloader)
        except Exception:
            # Best effort: fall back to a fresh schedule (weights may already have been loaded).
            training_difference = "step_0"
            starting_epoch = 0
            completed_steps = 0
            resume_step = None
            print("load checkpoint incomplete", traceback.format_exc())

    # update the progress_bar if load from checkpoint
    progress_bar.update(completed_steps)
    loss_info = torch.zeros(2, device=accelerator.device)  # sum, count

    for epoch in range(starting_epoch, args.num_train_epochs):
        model.train()
        train_dataloader.set_epoch(epoch)

        # Potentially clean up previous checkpoints.
        # NOTE: the actual deletion is disabled (commented out below), so `--cleanup_checkpoints`
        # currently only computes which directories would be pruned.
        if args.cleanup_checkpoints and os.path.exists(args.output_dir):
            dirs = [os.path.join(args.output_dir, f.name) for f in os.scandir(args.output_dir) if f.is_dir()]
            if len(dirs) > 3:  # must keep at least 2 checkpoints for second epoch and most recent one
                second_epoch_dir = os.path.join(args.output_dir, "epoch_1")
                if args.save_second_epoch and second_epoch_dir in dirs:  # never prune second epoch
                    dirs.remove(second_epoch_dir)
                dirs.sort(key=os.path.getctime)
                paths = dirs[:-3]  # only keep the last 3
                # for path in paths:
                #     print(f"remove rm -rf {path}")
                #     os.system(f"rm -rf {path}")

        if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
            # We skip the first `n` batches in the dataloader when resuming from a checkpoint
            active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
        else:
            active_dataloader = train_dataloader

        _time = time.time()
        dataloader_iter = iter(active_dataloader)
        model.train()
        num_iters_per_epoch = max(len(active_dataloader) - 8, 1)  # avoid the last few iters

        for step in range(num_iters_per_epoch):
            try:
                train_action_loss = 0
                batch = next(dataloader_iter)
                gradient_accumulation_steps = args.gradient_accumulation_steps
                batch_size = batch["input_ids"].size(0)

                # Manual gradient accumulation because accelerator somehow taking a lot of memory:
                # gradients are only synchronized (no `no_sync`) on the last micro-batch of each window.
                is_update_step = (step + 1) % gradient_accumulation_steps == 0
                ctx_manager = contextlib.nullcontext() if is_update_step else accelerator.no_sync(model)
                with ctx_manager:
                    accelerator.wait_for_everyone()
                    outputs = model(**batch)
                    loss = outputs.loss
                    if torch.isfinite(loss).all():
                        loss_info[0] += loss.detach().mean() * batch_size  # only video loss
                        if "action_loss" in outputs:
                            train_action_loss = outputs.action_loss.item()
                            loss += config.action_loss_weight * outputs.action_loss
                        loss_info[1] += batch_size
                        accelerator.backward(loss / gradient_accumulation_steps)
                    else:
                        print("Warning: NaN or Inf detected in loss. Skipping backward pass.")
                        # NOTE(review): this dummy loss is not connected to the model graph, so under DDP it
                        # does not trigger gradient reduction on this rank; confirm this cannot desync ranks.
                        dummy_loss = torch.zeros_like(loss, requires_grad=True)
                        accelerator.backward(dummy_loss)

                if not is_update_step:
                    continue
            except Exception:
                # avoid final iteration batch concatenation problems
                print("batch failed", traceback.format_exc())
                continue

            # Everything below only happens on update step
            if args.max_grad_norm is not None:
                accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            loss_info = accelerator.reduce(loss_info)
            avg_train_loss = (loss_info[0] / loss_info[1]).item()  # sum / count
            loss_info *= 0  # reset sum and count
            try:
                perplexity = math.exp(avg_train_loss)
            except OverflowError:
                print("overflow error for perplexity")
                perplexity = float("inf")

            batch_time = time.time() - _time  # accumulated batch
            domain_iter = str(batch['domain'][0])
            _time = time.time()
            accelerator.log(
                {
                    "train_perplexity": perplexity,
                    "train_loss": avg_train_loss,
                    "train_action_loss": train_action_loss,
                    # Was `train_action_loss / loss_info[1]` after `loss_info` had been reset to zero (inf/NaN).
                    f"stat/{domain_iter}_action_loss": train_action_loss,
                    f"stat/{domain_iter}_train_perplexity": perplexity,
                    f"stat/{domain_iter}_train_loss": avg_train_loss,
                    "epoch": epoch,
                    "update_step": completed_steps,
                    "examples_processed": completed_steps * args.per_device_train_batch_size
                                          * args.gradient_accumulation_steps * accelerator.num_processes,
                    "learning_rate": lr_scheduler.get_last_lr()[0],
                    "flops": (completed_steps + 1) * experiment_config["FLOPs_per_update_step"],
                    "throughput_examples": experiment_config["effective_batch_size"] / batch_time,
                }, step=completed_steps)

            progress_bar.update(1)
            completed_steps += 1

            if checkpoint_every and completed_steps % checkpoint_every == 0:
                print(f"Saving checkpoint at step {completed_steps}!")
                save_checkpoint(model, accelerator, args, f"step_{completed_steps}")

            if completed_steps % args.eval_every_n_steps == 0:
                time.sleep(1)  # manual adding time sleep
                model.eval()

                eval_losses = []
                eval_action_losses = []
                # Compute token-level accuracy (w/ teacher forcing)
                num_correct = 0
                num_total = 0
                # Iterate manually so a single failing batch (e.g. data collating issues) can be skipped.
                eval_dataloader_iter = iter(eval_dataloader)
                for _ in range(args.max_eval_steps):
                    eval_action_loss = 0
                    try:
                        batch = next(eval_dataloader_iter)
                        batch_size = len(batch["input_ids"])  # Last batch might not be full
                        with torch.no_grad():
                            outputs = model(**batch)

                        loss = outputs.loss
                        if "action_loss" in outputs:
                            eval_action_loss = outputs.action_loss.item()
                            eval_action_losses.append(eval_action_loss)
                        eval_losses.append(accelerator.gather_for_metrics(loss.repeat(batch_size)))
                    except StopIteration:
                        # Validation data exhausted. Without this, the generic handler below would keep
                        # looping for up to `max_eval_steps` (default 1e10) iterations.
                        break
                    except Exception as e:
                        print("error:", e)
                        continue

                    if "acc" in outputs:
                        # `num_correct` and `num_total` actually track mean accuracy in this case.
                        num_correct_batch = accelerator.reduce(outputs.acc, reduction="mean").item() * batch_size
                        num_total_batch = batch_size
                    else:
                        shifted_preds = torch.argmax(outputs.logits[:, :-1, :], dim=-1)
                        shifted_labels = batch["labels"][:, 1:]
                        num_correct_batch = accelerator.gather_for_metrics(
                            (shifted_preds == shifted_labels).sum()).sum().item()
                        num_total_batch = accelerator.gather_for_metrics(
                            torch.tensor(torch.numel(shifted_labels), device=accelerator.device)).sum().item()
                    num_correct += num_correct_batch
                    num_total += num_total_batch

                    try:
                        accelerator.log(
                            {
                                f'stat/{domain_iter}_eval_teacher_acc': num_correct_batch / num_total_batch,
                                f'stat/{domain_iter}_eval_loss': (torch.mean(eval_losses[-1])).item(),
                                f'stat/{domain_iter}_eval_action_loss': eval_action_loss,
                            },
                            step=completed_steps,
                        )
                    except Exception as e:
                        print("log failed", e)

                # Restore training mode; previously the model stayed in eval mode after evaluation.
                model.train()

                # Only the aggregate eval logging is skipped when nothing was evaluated; visualization and the
                # `max_train_steps` check below still run.
                if len(eval_losses) > 0:
                    eval_loss = torch.mean(torch.cat(eval_losses)).item()
                    eval_teacher_acc = num_correct / num_total if num_total > 0 else float("nan")
                    eval_action_loss = float(np.mean(eval_action_losses)) if eval_action_losses else 0
                    try:
                        perplexity = math.exp(eval_loss)
                    except OverflowError:
                        print("overflow error for perplexity")
                        perplexity = float("inf")

                    logger.info(f"{completed_steps=} {perplexity=} {eval_loss=} {eval_teacher_acc=}")
                    accelerator.log(
                        {
                            "eval_perplexity": perplexity,
                            "eval_loss": eval_loss,
                            "eval_action_loss": eval_action_loss,
                            "eval_teacher_acc": eval_teacher_acc,
                            "epoch": epoch,
                            "update_step": completed_steps,
                            "examples_processed": completed_steps * args.per_device_train_batch_size
                                                  * args.gradient_accumulation_steps * accelerator.num_processes,
                            "flops": completed_steps * experiment_config["FLOPs_per_update_step"],
                        },
                        step=completed_steps,
                    )

            if completed_steps % args.vis_every_n_steps == 0 or completed_steps >= args.max_train_steps:
                if "encoder_type" not in experiment_config:
                    experiment_config["encoder_name_or_path"] = "data/magvit2.ckpt"
                    experiment_config["encoder_type"] = "magvit"

                if not args.overfit_first_batch:  # val is same as train otherwise
                    visualize(accelerator, model, eval_dataloader, args.window_size, "val")

                visualize(accelerator, model, train_dataloader, args.window_size, "train")

            if completed_steps >= args.max_train_steps:
                break

        if args.checkpointing_steps == "epoch" or (args.save_second_epoch and epoch == 1):
            save_checkpoint(model, accelerator, args, f"epoch_{epoch}")

        # Stop the epoch loop as well once the step budget is exhausted.
        if completed_steps >= args.max_train_steps:
            break

    save_checkpoint(model, accelerator, args, "final_checkpt")
    accelerator.end_training()


def main():
    """Parse args, build datasets / STMaskGIT model / optimizer / LR scheduler, prepare them with Accelerate, train."""
    parser = parse_args()
    args = parser.parse_args()
    assert (args.llama_config is not None) ^ (args.genie_config is not None), \
        "Exactly one of `llama_config` and `genie_config` should be set."
    if args.llama_config is not None:
        # Only `GenieConfig` / `STMaskGIT` is constructed below; fail clearly instead of in `from_pretrained(None)`.
        raise NotImplementedError("`--llama_config` is not supported by this script; use `--genie_config`.")

    # Gradient accumulation is done manually in `train`, so the Accelerator itself accumulates over 1 step.
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(gradient_accumulation_steps=1, log_with=args.report_to, project_dir=args.output_dir,
                              kwargs_handlers=[ddp_kwargs])

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_info()
        print(f"Rank {accelerator.process_index} assigned to device {torch.cuda.current_device()}")
    else:
        transformers.utils.logging.set_verbosity_error()

    if args.seed is not None:
        set_seed(args.seed)

    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)

    accelerator.wait_for_everyone()

    config = GenieConfig.from_pretrained(args.genie_config)
    train_dataset = RawTokenDataset(args.train_data_dir, window_size=args.window_size, name=args.domain,
                                    stride=args.stride, filter_overlaps=args.filter_overlaps,
                                    compute_stride_from_freq_table=(args.stride is None),
                                    use_actions=config.use_actions)
    if not args.overfit_first_batch:
        eval_dataset = RawTokenDataset(args.val_data_dir, window_size=args.window_size, name=args.domain,
                                       stride=args.stride, filter_overlaps=True,
                                       compute_stride_from_freq_table=(args.stride is None),
                                       use_actions=config.use_actions)
    else:
        train_dataset.valid_start_inds = train_dataset.valid_start_inds[
            :args.per_device_train_batch_size * args.gradient_accumulation_steps * accelerator.num_processes]
        eval_dataset = train_dataset

    assert all(train_dataset.metadata[shared_key] == eval_dataset.metadata[shared_key]
               for shared_key in ("s", "vocab_size", "hz"))

    latent_side_len, vocab_size, hz = [train_dataset.metadata[key] for key in ("s", "vocab_size", "hz")]

    # Note: changing this may affect pre-trained model due to attn scaling
    config.use_mup = args.mu_transfer
    config.image_vocab_size = vocab_size
    config.T = args.window_size
    model = STMaskGIT(config)
    if config.use_actions:
        print(f"Initializing action projectors with {train_dataset.n_action}d action")
        model.init_action_projectors([train_dataset.name], [train_dataset.n_action], [train_dataset.action_stat],
                                     config.action_network)

    if args.mu_transfer:
        model.set_mup_shapes(rescale_params=True)
        # model.init_weights()  # might be unnecessary if `rescale_params` is True

    # Optimizer. Split weights in two groups, one with weight decay and the other not.
    no_decay = ["bias", "layer_norm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

    # Scale the base learning rate linearly with the effective batch size relative to 64, clamped to [1x, 8x].
    effective_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps \
        * accelerator.num_processes
    args.learning_rate = args.learning_rate * min(max(1, effective_batch_size / 64), 8)

    opt_class = mup.MuAdamW if args.mu_transfer else torch.optim.AdamW
    optimizer = opt_class(optimizer_grouped_parameters, lr=args.learning_rate,
                          betas=(args.adam_beta_1, args.adam_beta_2), eps=args.adam_eps)

    # DataLoaders creation:
    collate_fn = get_maskgit_collator(config)
    train_dataloader = DataLoader(
        train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.per_device_train_batch_size,
        num_workers=8, pin_memory=True,
    )

    # Shuffle eval dataset and then set shuffle=False on the dataloader.
    # Shuffling in the dataloader results in reshuffling with each iteration.
    eval_dataset.valid_start_inds = torch.tensor(eval_dataset.valid_start_inds)[
        torch.randperm(len(eval_dataset), generator=torch.Generator().manual_seed(0))
    ].tolist()

    eval_dataloader = DataLoader(
        eval_dataset, shuffle=False, collate_fn=collate_fn, batch_size=args.per_device_eval_batch_size,
        pin_memory=True, num_workers=8,
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        if args.max_train_steps < 2000 and args.resume_from_checkpoint is None:  # minimal number of training steps
            args.max_train_steps = 2000
            args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
        overrode_max_train_steps = True

    if args.lr_scheduler_type == "custom_cosine":
        # Linear warmup, then cosine decay to `end_ratio` of the peak learning rate (held there afterwards).
        def get_lr_wrapper(warmup_steps, max_steps, end_ratio=0.1):
            def get_lr(step):
                if step < warmup_steps:
                    return (step + 1) / warmup_steps
                # Guard against max_steps <= warmup_steps and clamp so the LR does not rise again past max_steps.
                remaining_steps = max(max_steps - warmup_steps, 1)
                progress = min((step - warmup_steps) / remaining_steps, 1.0)
                return ((1 + math.cos(math.pi * progress)) / 2) * (1 - end_ratio) + end_ratio

            return get_lr

        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            get_lr_wrapper(args.num_warmup_steps * accelerator.num_processes,
                           args.max_train_steps if overrode_max_train_steps
                           else args.max_train_steps * accelerator.num_processes)
        )
    else:
        lr_scheduler = get_scheduler(
            name=args.lr_scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=args.num_warmup_steps * accelerator.num_processes,
            num_training_steps=args.max_train_steps if overrode_max_train_steps
            else args.max_train_steps * accelerator.num_processes,
        )

    # Enable gradient checkpointing to save memory
    if args.gradient_checkpointing:
        logger.info("Enabling gradient checkpointing")
        model.gradient_checkpointing_enable()
        model.config.use_cache = False  # incompatible with grad checkpointing

    # Prepare everything with our `accelerator`.
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )

    if not args.no_compile:
        torch._dynamo.config.cache_size_limit = 256
        torch._dynamo.config.optimize_ddp = False  # https://github.com/pytorch/pytorch/issues/104674
        # TODO: https://github.com/pytorch/pytorch/issues/109774#issuecomment-2046633776
        model = torch.compile(model)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # Store the run configuration for the trackers.
    experiment_config = vars(args) | vars(config)

    seq_len = latent_side_len**2 * args.window_size
    args.num_datasets = 1
    model_module = model.module if hasattr(model, "module") else model
    num_params = sum(p.numel() for p in model_module.parameters())
    num_trunk_params = sum(p.numel() for p in model_module.decoder.parameters())
    experiment_config.update({
        "model_parameters": num_params,
        "model_parameters_M": round(num_params / 1e6),
        "trunk_parameters": num_trunk_params,
        "trunk_parameters_M": round(num_trunk_params / 1e6),
        "seq_len": seq_len,
        "hz": hz / train_dataset.stride,
        "train_data_tokens": len(train_dataset) * seq_len,  # only one epoch
        "effective_batch_size": effective_batch_size,
        "effective_batch_size_tokens": effective_batch_size * seq_len,
        "mixed_precision": accelerator.mixed_precision,
        "num_datasets": 1
    })

    print("============================")
    print(f"model parameters: {experiment_config['model_parameters_M']}M")
    print("============================")

    experiment_config["FLOPs_per_update_step"] = 6 * experiment_config["model_parameters"] \
        * experiment_config["effective_batch_size_tokens"]

    # Initialize trackers exactly once: in Accelerate each `init_trackers` call appends a new set of trackers,
    # so the earlier extra call made every `accelerator.log` go out twice.
    accelerator.init_trackers(project_name="video", config=experiment_config)
    if accelerator.is_main_process:
        accelerator.trackers[0].run.name = formatted_date + "_" + args.run_name

    train(accelerator, model, optimizer, lr_scheduler, train_dataloader, eval_dataloader, experiment_config, config,
          args)


if __name__ == "__main__":
    main()