import argparse
import json
import os
import random
import uuid
from contextlib import nullcontext
from copy import deepcopy
from datetime import datetime as dt
from datetime import timedelta
from functools import partial
from math import log2
from time import sleep, time
from typing import Any, Dict

import git
import numpy as np
import psutil
import torch
import torch.nn as nn
import torch.utils.data.distributed
import wandb
from PIL import Image
from torch import distributed as dist
from torch import optim
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from tqdm import tqdm

import unik3d.datasets as datasets
from unik3d.datasets import (ConcatDataset, DistributedSamplerNoDuplicate,
                             collate_fn, get_weights)
from unik3d.models import UniK3D
from unik3d.ops.scheduler import CosineScheduler
from unik3d.utils import (barrier, format_seconds, is_main_process,
                          log_train_artifacts, validate)
from unik3d.utils.distributed import (create_local_process_group,
                                      local_broadcast_process_authkey,
                                      setup_multi_processes, setup_slurm,
                                      sync_string_across_gpus,
                                      sync_tensor_across_gpus)
from unik3d.utils.ema_torch import (DummyExponentialMovingAverage,
                                    ExponentialMovingAverage)
from unik3d.utils.misc import calculate_mean_values

# The EMA handle is updated once every EMA_INTERVAL optimizer steps; its decay
# and tau are rescaled by this interval when the handle is built.
EMA_INTERVAL = 10
EMA_TAU = 10000
# NOTE(review): EMA_START is not referenced in this file — confirm whether it is still needed.
EMA_START = 50000

# Config string -> dtype used for torch.autocast.
MAP_DTYPE = {
    "f16": torch.float16,
    "bf16": torch.bfloat16,
    "f32": torch.float32,
}


def aggregate_sync_losses(dict_: dict[str, torch.Tensor], device):
    """Gather a {name: scalar} loss dict from every rank and reduce it per key.

    Keys and values are gathered with the project's sync helpers and then passed
    to ``calculate_mean_values`` (presumably a per-key mean across ranks — TODO
    confirm in ``unik3d.utils.misc``).

    Args:
        dict_: mapping from loss name to a scalar (0-dim tensor or float).
        device: device on which the value tensor is built before gathering.

    Returns:
        Whatever ``calculate_mean_values`` returns for the gathered keys/values.
    """
    keys = list(dict_.keys())
    values = torch.tensor(list(dict_.values()), device=device)
    keys = sync_string_across_gpus(keys, device)
    values = sync_tensor_across_gpus(values, dim=0).cpu().tolist()
    dict_ = calculate_mean_values(keys, values)
    return dict_


def _seed_everything(seed: int) -> None:
    """Seed Python, NumPy and torch (CPU + all CUDA devices) RNGs and PYTHONHASHSEED."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)


def _git_notes(repo_folder: str) -> str:
    """Describe the git commit/branch checked out at ``repo_folder``, or "" if unavailable."""
    try:
        repo = git.Repo(repo_folder)
        current_head = repo.head if repo.head.is_detached else repo.active_branch
        return f"MESSAGE: {current_head.commit.message} HASH:{current_head.commit.hexsha} BRANCH:{current_head.name}"
    except Exception:  # best effort: missing repo / odd HEAD must not stop training
        print(f"problem with {repo_folder}, does it exist?")
        return ""


def _dataset_index_size_mb(dataset) -> float:
    """Return the size in MB of a dataset's ``_addr``/``_lst`` index buffers.

    The factor of 8 assumes ``_addr`` entries are 8 bytes and ``_lst`` entries
    1 byte each — TODO confirm against the dataset implementation.
    """
    dim = (dataset.dataset._addr.numel() * 8 + dataset.dataset._lst.numel()) / (2**20)
    if hasattr(dataset, "sequences"):
        dim += (
            dataset.sequences._addr.numel() * 8 + dataset.sequences._lst.numel()
        ) / (2**20)
    return dim


def _gpu_memory_used_gb() -> float:
    """Return (total - free) memory of the current CUDA device in GiB."""
    free, total = torch.cuda.mem_get_info()
    return (total - free) / 2**30


def main_worker(config: Dict[str, Any], args: argparse.Namespace):
    """Train UniK3D on one process (single-GPU or one DDP rank).

    Sets up the (optional) distributed backend, seeds RNGs per rank, builds the
    model/optimizer/EMA/schedulers, the train and validation loaders, then runs
    the training loop with gradient accumulation, mixed precision, periodic
    logging and validation until ``config["training"]["n_iters"]`` steps.

    Args:
        config: parsed JSON config with "generic", "training" and "data" sections.
        args: CLI namespace; ``rank``/``local_rank``/``world_size`` are written here.

    Returns:
        0 once the final training step has been reached.
    """
    current_process = psutil.Process(os.getpid())
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    seed = config["generic"]["seed"]

    if not args.distributed:
        args.rank = 0
        args.local_rank = 0
        args.world_size = 1
    else:
        # initializes the distributed backend which will take care of synchronizing nodes/GPUs
        setup_multi_processes(config)
        is_slurm = "SLURM_PROCID" in os.environ
        if is_slurm:
            setup_slurm("nccl", port=args.master_port)
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
        # in DDP mode `device` becomes the integer local GPU index
        args.local_rank = device = int(os.environ["LOCAL_RANK"])
        if not is_slurm:
            dist.init_process_group(
                "nccl",
                rank=args.rank,
                world_size=args.world_size,
                timeout=timedelta(seconds=30 * 60),
            )
        torch.cuda.set_device(device)
        create_local_process_group()
        local_broadcast_process_authkey()
        print(
            f"Start running DDP on: {args.rank} (local: {args.local_rank}) with seed {seed + args.rank}."
        )
        # the configured batch size is global; each rank gets its share
        config["training"]["batch_size"] = int(
            config["training"]["batch_size"] / args.world_size
        )
        dist.barrier()

    # Fix seed
    # Different for every machine to avoid sampling
    # the same element across machines
    seed = seed + args.rank
    _seed_everything(seed)

    batch_size = config["training"]["batch_size"]
    if is_main_process():
        print("Config: ", args.config_file)
        print(
            f"Torch version:{torch.__version__}, cuda:{torch.version.cuda}, cudnn:{torch.backends.cudnn.version()}, threads:{torch.get_num_threads()}"
        )
        print("BatchSize per GPU: ", batch_size)
        print(
            f"Divided into {config['training']['nsteps_accumulation_gradient']} accumulation step"
        )

    ##############################
    ########### MODEL ############
    ##############################
    # Build model
    model = UniK3D(config).to(device)
    model.eval()
    print(f"MODEL: {model.__class__.__name__} at {model.device}")
    torch.cuda.empty_cache()
    if args.distributed:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = DistributedDataParallel(
            model,
            find_unused_parameters=False,
            device_ids=[device],
            output_device=device,
        )

    ##############################
    ######### OPTIMIZER ##########
    ##############################
    dtype_16bit = config["training"]["f16"]
    is_16bit = dtype_16bit != "f32"
    clipping = config["training"].get("clipping", None)

    # Optimize
    # `ddp_model` is always the bare UniK3D module (unwrapped when DDP is used),
    # so model-specific methods must be called on it rather than on `model`.
    ddp_model = model.module if args.distributed else model
    params = ddp_model.get_params(config)
    optimizer = optim.AdamW(
        params,
        eps=6e-8 if is_16bit else 1e-8,  # smallest subnormal fp16 number is 5.96e-8
        # amsgrad=is_16bit,  # use max instead of avg v_hat, avoid small number divisions?
    )

    # Load Model:
    step = 0
    if config["training"].get("pretrained", None) is not None:
        ddp_model.load_pretrained(config["training"]["pretrained"])
        # NOTE: weights_only=False unpickles arbitrary objects — only point this
        # at trusted checkpoints.
        pretrained = torch.load(
            config["training"]["pretrained"], map_location="cpu", weights_only=False
        )
        try:
            optimizer.load_state_dict(pretrained["optimizer"])
        except Exception as e:
            if is_main_process():
                print("Could not load optimizer state dict:", e)
        step = pretrained.get("step", 0)
        ddp_model.pixel_decoder.steps = step

    # EMA
    ema_class = (
        ExponentialMovingAverage
        if config["training"]["ema"] > 0.0
        else DummyExponentialMovingAverage
    )
    ema_handle = ema_class(
        ddp_model.parameters_grad(),
        # decay/tau are rescaled because the EMA only updates every EMA_INTERVAL steps
        1 - (1 - config["training"]["ema"]) * EMA_INTERVAL,
        update_after_step=config["training"]["warmup_iters"] / EMA_INTERVAL,
        switch=True,
        tau=EMA_TAU // EMA_INTERVAL,
    )
    ema_handle.num_updates = step // EMA_INTERVAL

    ##############################
    ######### GENERICS ###########
    ##############################
    resize_method = config["data"].get("resize_method", "hard")
    crop = config["data"].get("crop", "garg")
    augmentations_db = config["data"].get("augmentations", {})
    shape_constraints = config["data"].get("shape_constraints", {})
    image_shape = config["data"]["image_shape"]
    mini = config["data"]["mini"]
    nsteps_accumulation_gradient = config["training"]["nsteps_accumulation_gradient"]
    batch_size = config["training"]["batch_size"]
    clipping_fn = torch.nn.utils.clip_grad_norm_

    is_shell = int(os.environ.get("SHELL_JOB", 0))
    run_id = sync_string_across_gpus(
        [f"{dt.now().strftime('%d-%h_%H-%M')}-{uuid.uuid4()}"], device
    )[0]

    if not is_shell and is_main_process():
        repo_folder = os.path.dirname(os.path.realpath(__file__))
        notes = _git_notes(repo_folder)

        # restore the original batchsize, not acquired by other calls from now on
        # (only on the main process, for the wandb config record)
        if args.distributed:
            config["training"]["batch_size"] = (
                config["training"]["batch_size"] * args.world_size
            )
        wandb.init(
            project="UniK3D",
            name=run_id,
            config=config,
            tags=None,
            notes=notes,
            dir=os.environ.get("WANDB_HOME", os.environ.get("TMPDIR", "/tmp")),
        )
        wandb.watch(model)

    ##############################
    ########## DATASET ###########
    ##############################
    # Datasets loading
    train_datasets, val_datasets = {}, {}
    if is_main_process():
        print("Loading training datasets...")
    dims = 0

    for dataset in config["data"]["train_datasets"]:
        assert hasattr(datasets, dataset), f"{dataset} not a custom dataset"
        train_dataset: datasets.BaseDataset = getattr(datasets, dataset)
        train_datasets[dataset] = train_dataset(
            image_shape=image_shape,
            split_file=train_dataset.train_split,
            test_mode=False,
            crop=crop,
            augmentations_db=augmentations_db,
            shape_constraints=shape_constraints,
            normalize=config["data"].get("normalization", "imagenet"),
            resize_method=resize_method,
            mini=mini,
            num_frames=config["data"].get("num_frames", 1),
            fps_range=[1, 5],
            num_copies=config["data"]["pair"],
        )
        dim = _dataset_index_size_mb(train_datasets[dataset])
        dims = dims + dim
        if is_main_process():
            print(f"{dataset}: {dim:.1f}MB")

    if is_main_process():
        print(f"All training datasets loaded, with total size: {dims:.1f}MB")

    barrier()

    # Each sample yields `pair` copies, and each optimizer step is split into
    # `nsteps_accumulation_gradient` equal chunks.
    assert batch_size % config["data"]["pair"] == 0, (
        f"per-GPU batch size {batch_size} not divisible by data.pair={config['data']['pair']}"
    )
    batch_size = batch_size // config["data"]["pair"]
    assert batch_size % nsteps_accumulation_gradient == 0, (
        f"batch size {batch_size} not divisible by nsteps_accumulation_gradient={nsteps_accumulation_gradient}"
    )
    batch_chunk = batch_size // nsteps_accumulation_gradient

    train_dataset = ConcatDataset(
        list(train_datasets.values()),
        shape_constraints=shape_constraints,
    )

    if is_main_process():
        print("Loading validation datasets...")
    for dataset in config["data"]["val_datasets"]:
        val_dataset: datasets.BaseDataset = getattr(datasets, dataset)
        val_datasets[dataset] = val_dataset(
            image_shape=image_shape,
            split_file=val_dataset.test_split,
            test_mode=True,
            crop=crop,
            shape_constraints=shape_constraints,
            augmentations_db=augmentations_db,
            normalize=config["data"].get("normalization", "imagenet"),
            resize_method=resize_method,
            num_frames=1,
            mini=1.0,
            num_copies=1,
        )

    # Dataset samplers, create distributed sampler pinned to rank
    if args.distributed:
        sampling = deepcopy(config["data"]["sampling"])
        weights, num_samples = get_weights(train_datasets, sampling)
        # not rank-aware: ranks draw different samples because they are seeded differently
        train_sampler = torch.utils.data.WeightedRandomSampler(
            weights, num_samples, replacement=True
        )
        valid_samplers = {
            k: DistributedSamplerNoDuplicate(
                v,
                num_replicas=args.world_size,
                rank=args.rank,
                shuffle=False,
                drop_last=False,
            )
            for k, v in val_datasets.items()
        }
    else:
        train_sampler = RandomSampler(train_dataset)
        valid_samplers = {k: SequentialSampler(v) for k, v in val_datasets.items()}

    # The loader receives whole index batches from this sampler.
    train_sampler = torch.utils.data.BatchSampler(
        train_sampler, batch_size=batch_size, drop_last=True
    )

    # Dataset loader
    val_batch_size = 1
    num_workers = int(os.environ.get("SLURM_CPUS_PER_TASK", 4))
    train_loader = DataLoader(
        train_dataset,
        num_workers=num_workers,
        sampler=train_sampler,
        pin_memory=True,
        collate_fn=partial(collate_fn, is_batched=True),
        persistent_workers=num_workers > 0,
    )

    val_loaders = {
        name_dataset: DataLoader(
            dataset,
            batch_size=val_batch_size,
            shuffle=False,
            num_workers=num_workers,
            sampler=valid_samplers[name_dataset],
            pin_memory=True,
            drop_last=False,
            collate_fn=partial(collate_fn, is_batched=False),
        )
        for name_dataset, dataset in val_datasets.items()
    }

    # SCHEDULERS!
    # step_init=step - 1 resumes the schedules from a loaded checkpoint step.
    scheduler_wd = CosineScheduler(
        optimizer,
        key="weight_decay",
        init_value=config["training"]["wd"],
        base_value=config["training"]["wd"],
        final_value=config["training"]["wd_final"],
        warmup_iters=0,
        total_iters=config["training"]["n_iters"],
        flat_iters=config["training"]["warmup_iters"],
        step_init=step - 1,
    )
    scheduler_lr = CosineScheduler(
        optimizer,
        key="lr",
        init_value=config["training"]["lr"] * config["training"].get("lr_warmup", 1.0),
        final_value=config["training"]["lr_final"],
        warmup_iters=5000,
        flat_iters=config["training"]["warmup_iters"],
        total_iters=config["training"]["n_iters"],
        step_init=step - 1,
    )
    cycle_betas = config["training"].get("cycle_betas", True)
    scheduler_betas = CosineScheduler(
        optimizer,
        key="betas",
        init_value=0.95 if cycle_betas else 0.9,
        base_value=0.85 if cycle_betas else 0.9,
        final_value=0.95 if cycle_betas else 0.9,
        warmup_iters=config["training"]["warmup_iters"],
        total_iters=config["training"]["n_iters"],
        step_init=step - 1,
    )

    # Set loss scaler for half precision training + sanity zeroing grads
    dtype = MAP_DTYPE[dtype_16bit]
    if not torch.cuda.is_bf16_supported() and is_16bit:
        dtype = torch.float16
    context = torch.autocast(device_type="cuda", dtype=dtype, enabled=is_16bit)
    # use float16 to check for instability at inference and avoid bfloat16 for coarseness
    context_val = torch.autocast(
        device_type="cuda", dtype=torch.float16, enabled=is_16bit
    )
    optimizer.zero_grad(set_to_none=True)

    ##############################
    ########## TRAINING ##########
    ##############################
    # Remember that if i-th layer is frozen, this will break gradient checkpointing
    # in layer i+1-th. This is because CheckpointFunction treats the i+1-th input as
    # without gradient, thus the i+1-th layer does not have grads (?). To solve it,
    # just add requires_grad_() to the inputs coming from the frozen layer
    ddp_model.train()

    start = time()
    n_steps = config["training"]["n_iters"]
    init_steps = int(step)
    track_pbar = is_shell

    if is_main_process():
        print("Is a shell job?", is_shell)
        print("Use dtype:", dtype if is_16bit else torch.float32)
        print(
            f'Train for {config["training"]["n_iters"]} steps, validate every {config["training"]["validation_interval"]} steps'
        )
        print(f"START with {num_workers} workers")
        if track_pbar:
            pbar = tqdm(total=n_steps - init_steps)

    # Pick the initial scale from the *effective* autocast dtype: when bf16 was
    # requested but is unsupported we fall back to fp16 above, and a 2**40
    # starting scale would overflow fp16 for dozens of steps.
    scaler = torch.amp.GradScaler(
        "cuda",
        init_scale=2**14 if dtype == torch.float16 else 2**40,
        enabled=is_16bit,
        growth_factor=1.2,
        backoff_factor=0.8,
        growth_interval=500,
    )
    track_losses = {}
    # Exponential moving averages (0.99 smoothing) of resource usage, in GiB.
    system_memory = psutil.virtual_memory().available / 2**30
    cpid_memory = current_process.memory_info().rss / 2.0**30
    gpu_mem = _gpu_memory_used_gb()

    while True:
        for j, batches in enumerate(train_loader):
            system_memory = (
                0.99 * system_memory
                + 0.01 * psutil.virtual_memory().available / 2**30
            )
            cpid_memory = (
                0.99 * cpid_memory + 0.01 * current_process.memory_info().rss / 2.0**30
            )
            gpu_mem = 0.99 * gpu_mem + 0.01 * _gpu_memory_used_gb()
            if j % 1000 == 0 and is_main_process():
                print(f"System information at step {j}")
                print(f"System-wide RAM available: {system_memory:.2f}GB")
                print(f"Process RAM used: {cpid_memory:.2f}GB")
                print(f"CPU utilization: {psutil.cpu_percent(interval=None)}%")
                print(f"GPU memory utilized: {gpu_mem:.2f}GB")

            batches["data"] = {
                k: v.to(model.device, non_blocking=True)
                for k, v in batches["data"].items()
            }
            for idx in range(nsteps_accumulation_gradient):
                batch = {}
                batch_slice = slice(idx * batch_chunk, (idx + 1) * batch_chunk)
                batch["data"] = {k: v[batch_slice] for k, v in batches["data"].items()}
                batch["img_metas"] = batches["img_metas"][batch_slice]
                # Skip DDP gradient all-reduce on all but the last accumulation
                # chunk. `no_sync` only exists on the DDP wrapper.
                sync_ctx = (
                    model.no_sync()
                    if args.distributed and idx < nsteps_accumulation_gradient - 1
                    else nullcontext()
                )
                with sync_ctx:
                    with context:
                        preds, losses = model(batch["data"], batch["img_metas"])
                        loss = sum(losses["opt"].values())
                    # backward outside autocast, as recommended for AMP
                    scaler.scale(loss).backward()

                losses_dict = {
                    k: v.detach() for loss in losses.values() for k, v in loss.items()
                }
                # non-finite losses are recorded as 1e5 so they stand out in the logs
                track_losses.update(
                    {
                        k: track_losses.get(k, 0.0)
                        + torch.nan_to_num(v, nan=1e5, posinf=1e5, neginf=1e5)
                        for k, v in losses_dict.items()
                    }
                )
                ddp_model.loss_history = track_losses

            if clipping is not None:
                scaler.unscale_(optimizer)
                grad_norm = clipping_fn(ddp_model.parameters_grad(), clipping)
                if torch.isfinite(grad_norm):
                    track_losses.update(
                        {"Grad_Norm": track_losses.get("Grad_Norm", 0.0) + grad_norm}
                    )

            # there is a deeper issue, either log/sqrt of negative loss
            # or the inputs create large values and destroy model weights
            if is_16bit and scaler.get_scale() < 1:
                raise ValueError("Scale went less than 1, ISSUE!!!")

            scaler.step(optimizer)
            scaler.update()

            scheduler_wd.step()
            scheduler_lr.step()
            scheduler_betas.step()
            # ddp_model is the unwrapped module in both single- and multi-GPU runs
            ddp_model.step()
            optimizer.zero_grad(set_to_none=True)
            if step % EMA_INTERVAL == 0:
                ema_handle.update()

            if is_main_process() and track_pbar:
                pbar.update(1)

            step += 1

            # LOGGING
            if step % 100 == 0 and is_main_process():
                log_num = min(10, preds["depth"].shape[0])
                log_train_artifacts(
                    batch["data"]["image"][-log_num:, 0].float(),
                    (
                        batch["data"]["depth"][-log_num:, 0].float()
                        if "depth" in batch["data"]
                        else []
                    ),
                    preds["depth"][-log_num:, 0].detach().float(),
                    infos={
                        k: v[-log_num:, 0] for k, v in preds.get("infos", {}).items()
                    },
                    step=step,
                )

            if step % 50 == 0:
                # losses were accumulated once per accumulation chunk
                track_losses = {
                    k: v / (50 * nsteps_accumulation_gradient)
                    for k, v in track_losses.items()
                }
                # grad norm is for every step! (accumulated once per optimizer step,
                # and absent when clipping is disabled or every norm was non-finite)
                if "Grad_Norm" in track_losses:
                    track_losses["Grad_Norm"] = (
                        track_losses["Grad_Norm"] * nsteps_accumulation_gradient
                    )
                track_losses = aggregate_sync_losses(track_losses, device=model.device)
                if is_main_process():
                    elapsed = int(time() - start)
                    eta = int(elapsed * (n_steps - step) / max(1, step - init_steps))
                    print(
                        f"Step {step}/{n_steps} [{format_seconds(elapsed)}<{format_seconds(eta)}]"
                    )
                    try:
                        wandb.log(
                            {
                                **{f"Train/{k}": v for k, v in track_losses.items()},
                                "Train/lr": scheduler_lr.get()[-1],
                                "Train/wd": scheduler_wd.get()[-2],
                                "Train/scale_f16": log2(scaler.get_scale()),
                            },
                            step=step,
                        )
                    except Exception as e:
                        # best effort, e.g. wandb not initialised in shell jobs
                        print("Not logging loss because of:", e)

                    if step % 100 == 0:
                        log_loss_dict = {
                            f"Train/{k}": v for k, v in track_losses.items()
                        }
                        print(
                            ", ".join(
                                [f"{k}: {v:.5f}" for k, v in log_loss_dict.items()]
                            )
                        )

                track_losses = {}  # reinit every 50 steps, average the current 50 steps

            # Validation
            is_last_step = step >= config["training"]["n_iters"]
            is_validation = step % config["training"]["validation_interval"] == 0
            if is_last_step or is_validation:
                torch.cuda.empty_cache()
                barrier()
                if is_main_process():
                    print(f"Validation at {step}th step...")
                ddp_model.eval()
                start_validation = time()
                with torch.no_grad(), ema_handle.average_parameters():
                    validate(
                        model,
                        test_loaders=val_loaders,
                        step=step,
                        run_id=run_id,
                        idxs=(64, 96, 224, 256),  # random
                        context=context_val,
                    )

                if is_main_process():
                    print(f"Elapsed: {format_seconds(int(time() - start_validation))}")
                ddp_model.train()
                torch.cuda.empty_cache()

            if step >= config["training"]["n_iters"]:
                if is_main_process() and track_pbar:
                    pbar.close()
                wandb.finish(0)
                # a process group only exists when running distributed
                if args.distributed:
                    dist.destroy_process_group()
                return 0


if __name__ == "__main__":
    if "SLURM_PROCID" in os.environ:
        os.environ["TRITON_CACHE_DIR"] = "/tmp"

    # Arguments
    parser = argparse.ArgumentParser(
        description="Training script", conflict_handler="resolve"
    )
    parser.add_argument("--config-file", type=str, required=True)
    parser.add_argument("--master-port", type=str)
    parser.add_argument("--distributed", action="store_true")
    parser.add_argument("--local_rank", type=int, default=0)

    args = parser.parse_args()
    with open(args.config_file, "r", encoding="utf-8") as f:
        config = json.load(f)

    deterministic = config["generic"].get("deterministic", True)
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic

    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.set_float32_matmul_precision("high")
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.set_num_threads(1)
    main_worker(config, args)