# UniK3D training script: distributed-aware entry point for model training.
import argparse
import json
import os
import random
import uuid
from contextlib import nullcontext
from copy import deepcopy
from datetime import datetime as dt
from functools import partial
from math import log2
from time import sleep, time
from typing import Any, Dict

import git
import numpy as np
import psutil
import torch
import torch.nn as nn
import torch.utils.data.distributed
import wandb
from PIL import Image
from torch import distributed as dist
from torch import optim
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from tqdm import tqdm

import unik3d.datasets as datasets
from unik3d.datasets import (ConcatDataset, DistributedSamplerNoDuplicate,
                             collate_fn, get_weights)
from unik3d.models import UniK3D
from unik3d.ops.scheduler import CosineScheduler
from unik3d.utils import (barrier, format_seconds, is_main_process,
                          log_train_artifacts, validate)
from unik3d.utils.distributed import (create_local_process_group,
                                      local_broadcast_process_authkey,
                                      setup_multi_processes, setup_slurm,
                                      sync_string_across_gpus,
                                      sync_tensor_across_gpus)
from unik3d.utils.ema_torch import (DummyExponentialMovingAverage,
                                    ExponentialMovingAverage)
from unik3d.utils.misc import calculate_mean_values
# Apply the EMA weight update only every EMA_INTERVAL optimizer steps
# (the decay passed to the EMA handle below is rescaled accordingly).
EMA_INTERVAL = 10
# Time constant (in optimizer steps) used for the EMA's tau schedule.
EMA_TAU = 10000
# NOTE(review): EMA_START appears unused in this file — confirm before removing.
EMA_START = 50000
# Map the config's precision string ("training.f16") to a torch dtype.
MAP_DTYPE = {
    "f16": torch.float16,
    "bf16": torch.bfloat16,
    "f32": torch.float32,
}
def aggregate_sync_losses(dict_: dict[str, torch.Tensor], device):
    """Average a per-rank loss dictionary across all GPUs.

    Keys (strings) and values (a 1-D tensor) are synchronized separately
    across ranks; entries sharing the same key are then averaged. Returns
    a plain dict mapping each loss name to its mean over all ranks.
    """
    local_keys = list(dict_.keys())
    local_values = torch.tensor(list(dict_.values()), device=device)
    # Exchange this rank's keys/values with every other rank.
    gathered_keys = sync_string_across_gpus(local_keys, device)
    gathered_values = sync_tensor_across_gpus(local_values, dim=0).cpu().tolist()
    # Collapse duplicated keys into their mean value.
    return calculate_mean_values(gathered_keys, gathered_values)
def main_worker(config: Dict[str, Any], args: argparse.Namespace):
    """Run the full UniK3D training loop on one (possibly distributed) worker.

    Sets up the distributed process group and per-rank seeding, builds the
    model / optimizer / EMA handle, loads train and validation datasets,
    creates cosine schedulers for lr, weight decay and betas, then trains
    with gradient accumulation and mixed precision, periodically logging to
    wandb and validating. Returns 0 once ``config["training"]["n_iters"]``
    steps have been completed.
    """
    current_process = psutil.Process(os.getpid())
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    seed = config["generic"]["seed"]
    if not args.distributed:
        args.rank = 0
        args.local_rank = 0
        args.world_size = 1
    else:
        # initializes the distributed backend which will take care of synchronizing nodes/GPUs
        setup_multi_processes(config)
        is_slurm = "SLURM_PROCID" in os.environ
        if is_slurm:
            setup_slurm("nccl", port=args.master_port)
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
        # NOTE(review): from here on `device` is rebound to the integer local
        # rank (accepted by .to()/set_device), no longer a torch.device.
        args.local_rank = device = int(os.environ["LOCAL_RANK"])
        if not is_slurm:
            import datetime

            dist.init_process_group(
                "nccl",
                rank=args.rank,
                world_size=args.world_size,
                timeout=datetime.timedelta(seconds=30 * 60),
            )
        torch.cuda.set_device(device)
        create_local_process_group()
        local_broadcast_process_authkey()
        print(
            f"Start running DDP on: {args.rank} (local: {args.local_rank}) with seed {seed + args.rank}."
        )
        # The config's batch size is global: split it evenly across ranks.
        config["training"]["batch_size"] = int(
            config["training"]["batch_size"] / args.world_size
        )
        dist.barrier()

    # Fix seed
    # Different for every machine to avoid sampling
    # the same element across machines
    seed = seed + args.rank
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)

    batch_size = config["training"]["batch_size"]
    if is_main_process():
        print("Config: ", args.config_file)
        print(
            f"Torch version:{torch.__version__}, cuda:{torch.version.cuda}, cudnn:{torch.backends.cudnn.version()}, threads:{torch.get_num_threads()}"
        )
        print("BatchSize per GPU: ", batch_size)
        print(
            f"Divided into {config['training']['nsteps_accumulation_gradient']} accumulation step"
        )

    ##############################
    ########### MODEL ############
    ##############################
    # Build model
    model = UniK3D(config).to(device)
    model.eval()
    print(f"MODEL: {model.__class__.__name__} at {model.device}")
    torch.cuda.empty_cache()
    if args.distributed:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = DistributedDataParallel(
            model,
            find_unused_parameters=False,
            device_ids=[device],
            output_device=device,
        )

    ##############################
    ######### OPTIMIZER ##########
    ##############################
    dtype_16bit = config["training"]["f16"]
    is_16bit = dtype_16bit != "f32"
    clipping = config["training"].get("clipping", None)
    # Optimize
    # `ddp_model` is the raw module (unwrapped from DDP when distributed).
    ddp_model = model.module if args.distributed else model
    params = ddp_model.get_params(config)
    optimizer = optim.AdamW(
        params,
        eps=6e-8 if is_16bit else 1e-8,  # smallest subnormal fp16 number is 5.96e-8
        # amsgrad=is_16bit, # use max instead of avg v_hat, avoid small number divisions?
    )

    # Load Model:
    step = 0
    if config["training"].get("pretrained", None) is not None:
        ddp_model.load_pretrained(config["training"]["pretrained"])
        # The checkpoint is reloaded here to also restore optimizer state and step.
        pretrained = torch.load(
            config["training"]["pretrained"], map_location="cpu", weights_only=False
        )
        try:
            optimizer.load_state_dict(pretrained["optimizer"])
        except Exception as e:
            # Best-effort: a mismatched optimizer state is tolerated (e.g. fine-tuning).
            if is_main_process():
                print("Could not load optimizer state dict:", e)
        step = pretrained.get("step", 0)
        ddp_model.pixel_decoder.steps = step

    # EMA
    # Dummy (no-op) EMA when config disables it; the decay is rescaled because
    # updates happen only every EMA_INTERVAL optimizer steps.
    ema_class = (
        ExponentialMovingAverage
        if config["training"]["ema"] > 0.0
        else DummyExponentialMovingAverage
    )
    ema_handle = ema_class(
        ddp_model.parameters_grad(),
        1 - (1 - config["training"]["ema"]) * EMA_INTERVAL,
        update_after_step=config["training"]["warmup_iters"] / EMA_INTERVAL,
        switch=True,
        tau=EMA_TAU // EMA_INTERVAL,
    )
    # Resume the EMA update counter when restarting from a checkpoint.
    setattr(ema_handle, "num_updates", step // EMA_INTERVAL)

    ##############################
    ######### GENERICS ###########
    ##############################
    resize_method = config["data"].get("resize_method", "hard")
    crop = config["data"].get("crop", "garg")
    augmentations_db = config["data"].get("augmentations", {})
    shape_constraints = config["data"].get("shape_constraints", {})
    image_shape = config["data"]["image_shape"]
    mini = config["data"]["mini"]
    nsteps_accumulation_gradient = config["training"]["nsteps_accumulation_gradient"]
    batch_size = config["training"]["batch_size"]
    clipping_fn = torch.nn.utils.clip_grad_norm_
    is_shell = int(os.environ.get("SHELL_JOB", 0))
    # All ranks must agree on the run id (used for validation artifacts).
    run_id = sync_string_across_gpus(
        [f"{dt.now().strftime('%d-%h_%H-%M')}-{uuid.uuid4()}"], device
    )[0]
    if not is_shell and is_main_process():
        repo_folder = os.path.dirname(os.path.realpath(__file__))
        try:
            repo = git.Repo(repo_folder)
            current_head = repo.head if repo.head.is_detached else repo.active_branch
            notes = f"MESSAGE: {current_head.commit.message} HASH:{current_head.commit.hexsha} BRANCH:{current_head.name}"
        except:
            # NOTE(review): bare except — any failure (not just a missing git
            # repo) silently falls back to empty notes; consider narrowing.
            print(f"problem with {repo_folder}, does it exist?")
            notes = ""
        # restore the original batchsize, not acquired by other calls from now on
        if args.distributed:
            config["training"]["batch_size"] = (
                config["training"]["batch_size"] * args.world_size
            )
        wandb.init(
            project="UniK3D",
            name=run_id,
            config=config,
            tags=None,
            notes=notes,
            dir=os.environ.get("WANDB_HOME", os.environ.get("TMPDIR", "/tmp")),
        )
        wandb.watch(model)

    ##############################
    ########## DATASET ###########
    ##############################
    # Datasets loading
    train_datasets, val_datasets = {}, {}
    if is_main_process():
        print("Loading training datasets...")
    dims = 0
    for dataset in config["data"]["train_datasets"]:
        assert hasattr(datasets, dataset), f"{dataset} not a custom dataset"
        train_dataset: datasets.BaseDataset = getattr(datasets, dataset)
        train_datasets[dataset] = train_dataset(
            image_shape=image_shape,
            split_file=train_dataset.train_split,
            test_mode=False,
            crop=crop,
            augmentations_db=augmentations_db,
            shape_constraints=shape_constraints,
            normalize=config["data"].get("normalization", "imagenet"),
            resize_method=resize_method,
            mini=mini,
            num_frames=config["data"].get("num_frames", 1),
            fps_range=[1, 5],
            num_copies=config["data"]["pair"],
        )
        # Rough in-memory footprint of the dataset index structures, in MB.
        # (assumes `_addr` holds 8-byte offsets and `_lst` raw bytes — TODO confirm)
        dim = (
            train_datasets[dataset].dataset._addr.numel() * 8
            + train_datasets[dataset].dataset._lst.numel()
        ) / (2**20)
        if hasattr(train_datasets[dataset], "sequences"):
            dim += (
                train_datasets[dataset].sequences._addr.numel() * 8
                + train_datasets[dataset].sequences._lst.numel()
            ) / (2**20)
        dims = dims + dim
        if is_main_process():
            print(f"{dataset}: {dim:.1f}MB")
    print(f"All training datasets loaded, with total size: {dims:.1f}MB")
    barrier()

    # Each sample yields `pair` copies, so the effective per-GPU batch shrinks;
    # it must then split evenly into gradient-accumulation chunks.
    assert batch_size % config["data"]["pair"] == 0
    batch_size = batch_size // config["data"]["pair"]
    assert batch_size % nsteps_accumulation_gradient == 0
    batch_chunk = batch_size // nsteps_accumulation_gradient
    train_dataset = ConcatDataset(
        list(train_datasets.values()),
        shape_constraints=shape_constraints,
    )

    if is_main_process():
        print("Loading validation datasets...")
    for dataset in config["data"]["val_datasets"]:
        val_dataset: datasets.BaseDataset = getattr(datasets, dataset)
        val_datasets[dataset] = val_dataset(
            image_shape=image_shape,
            split_file=val_dataset.test_split,
            test_mode=True,
            crop=crop,
            shape_constraints=shape_constraints,
            augmentations_db=augmentations_db,
            normalize=config["data"].get("normalization", "imagenet"),
            resize_method=resize_method,
            num_frames=1,
            mini=1.0,
            num_copies=1,
        )

    # Dataset samplers, create distributed sampler pinned to rank
    if args.distributed:
        sampling = deepcopy(config["data"]["sampling"])
        weights, num_samples = get_weights(train_datasets, sampling)
        train_sampler = torch.utils.data.WeightedRandomSampler(
            weights, num_samples, replacement=True
        )
        valid_samplers = {
            k: DistributedSamplerNoDuplicate(
                v,
                num_replicas=args.world_size,
                rank=args.rank,
                shuffle=False,
                drop_last=False,
            )
            for k, v in val_datasets.items()
        }
    else:
        train_sampler = RandomSampler(train_dataset)
        valid_samplers = {k: SequentialSampler(v) for k, v in val_datasets.items()}
    # NOTE(review): the BatchSampler is passed below as `sampler=` (not
    # `batch_sampler=`), so the dataset receives a whole list of indices per
    # item and collate runs with is_batched=True — presumably intentional;
    # verify against ConcatDataset.__getitem__.
    train_sampler = torch.utils.data.BatchSampler(
        train_sampler, batch_size=batch_size, drop_last=True
    )

    # Dataset loader
    val_batch_size = 1
    num_workers = int(os.environ.get("SLURM_CPUS_PER_TASK", 4))
    train_loader = DataLoader(
        train_dataset,
        num_workers=num_workers,
        sampler=train_sampler,
        pin_memory=True,
        collate_fn=partial(collate_fn, is_batched=True),
        # NOTE(review): None where a bool is expected — falsy, so equivalent
        # to False when num_workers == 0.
        persistent_workers=True if num_workers else None,
    )
    val_loaders = {
        name_dataset: DataLoader(
            dataset,
            batch_size=val_batch_size,
            shuffle=False,
            num_workers=num_workers,
            sampler=valid_samplers[name_dataset],
            pin_memory=True,
            drop_last=False,
            collate_fn=partial(collate_fn, is_batched=False),
        )
        for name_dataset, dataset in val_datasets.items()
    }

    # SCHEDULERS!
    # step_init = step - 1 so a resumed run continues the cosine curves.
    scheduler_wd = CosineScheduler(
        optimizer,
        key="weight_decay",
        init_value=config["training"]["wd"],
        base_value=config["training"]["wd"],
        final_value=config["training"]["wd_final"],
        warmup_iters=0,
        total_iters=config["training"]["n_iters"],
        flat_iters=config["training"]["warmup_iters"],
        step_init=step - 1,
    )
    scheduler_lr = CosineScheduler(
        optimizer,
        key="lr",
        init_value=config["training"]["lr"] * config["training"].get("lr_warmup", 1.0),
        final_value=config["training"]["lr_final"],
        warmup_iters=5000,
        flat_iters=config["training"]["warmup_iters"],
        total_iters=config["training"]["n_iters"],
        step_init=step - 1,
    )
    scheduler_betas = CosineScheduler(
        optimizer,
        key="betas",
        init_value=0.95 if config["training"].get("cycle_betas", True) else 0.9,
        base_value=0.85 if config["training"].get("cycle_betas", True) else 0.9,
        final_value=0.95 if config["training"].get("cycle_betas", True) else 0.9,
        warmup_iters=config["training"]["warmup_iters"],
        total_iters=config["training"]["n_iters"],
        step_init=step - 1,
    )

    # Set loss scaler for half precision training + sanity zeroing grads
    dtype = MAP_DTYPE[dtype_16bit]
    if not torch.cuda.is_bf16_supported() and is_16bit:
        dtype = torch.float16
    context = torch.autocast(device_type="cuda", dtype=dtype, enabled=is_16bit)
    # use float16 to check for instability at inference an avoid bfloat16 for coarseness
    context_val = torch.autocast(
        device_type="cuda", dtype=torch.float16, enabled=is_16bit
    )
    optimizer.zero_grad(set_to_none=True)

    ##############################
    ########## TRAINING ##########
    ##############################
    # Remember that if i-th layer is frozen, this will break gradient checkpointing
    # in layer i+1-th. This is because CheckpointFunction treats the i+1-th input as
    # without gradient, thus the i+1-th layer does not have grads (?). To solve it,
    # just add requires_grad_() to the inputs coming from the frozen layer
    ddp_model.train()
    start = time()
    n_steps = config["training"]["n_iters"]
    init_steps = int(step)
    track_pbar = is_shell
    if is_main_process():
        print("Is a shell job?", is_shell)
        print("Use dtype:", dtype if is_16bit else torch.float32)
        print(
            f'Train for {config["training"]["n_iters"]} steps, validate every {config["training"]["validation_interval"]} steps'
        )
        print(f"START with {num_workers} workers")
        if track_pbar:
            pbar = tqdm(total=n_steps - init_steps)
    # f16 needs a modest initial scale; bf16/f32 get an effectively inert one.
    scaler = torch.amp.GradScaler(
        "cuda",
        init_scale=2**14 if dtype_16bit == "f16" else 2**40,
        enabled=is_16bit,
        growth_factor=1.2,
        backoff_factor=0.8,
        growth_interval=500,
    )
    # NOTE(review): `track_grad` appears unused.
    track_losses, track_grad = {}, {}
    # Exponential moving averages of memory usage for periodic reporting.
    system_memory = dict(psutil.virtual_memory()._asdict())["available"] / 2**30
    cpid_memory = current_process.memory_info()[0] / 2.0**30
    gpu_mem = (torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / 2**30
    while True:
        for j, batches in enumerate(train_loader):
            system_memory = (
                0.99 * system_memory
                + 0.01 * dict(psutil.virtual_memory()._asdict())["available"] / 2**30
            )
            cpid_memory = (
                0.99 * cpid_memory + 0.01 * current_process.memory_info()[0] / 2.0**30
            )
            gpu_mem = (
                0.99 * gpu_mem
                + 0.01
                * (torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0])
                / 2**30
            )
            if j % 1000 == 0 and is_main_process():
                print(f"System information at step {j}")
                print(f"System-wide RAM available: {system_memory:.2f}GB")
                print(f"CPU utilization: {psutil.cpu_percent(interval=None)}%")
                print(f"GPU memory utilized: {gpu_mem:.2f}GB")
            batches["data"] = {
                k: v.to(model.device, non_blocking=True)
                for k, v in batches["data"].items()
            }
            # Gradient accumulation: split the batch into chunks; skip DDP
            # gradient sync on all but the last chunk.
            for idx in range(nsteps_accumulation_gradient):
                batch = {}
                batch_slice = slice(idx * batch_chunk, (idx + 1) * batch_chunk)
                batch["data"] = {k: v[batch_slice] for k, v in batches["data"].items()}
                batch["img_metas"] = batches["img_metas"][batch_slice]
                with (
                    model.no_sync()
                    if idx < nsteps_accumulation_gradient - 1
                    else nullcontext()
                ):
                    with context:
                        preds, losses = model(batch["data"], batch["img_metas"])
                        loss = sum(losses["opt"].values())
                    scaler.scale(loss).backward()
            # Accumulate per-loss running sums (from the last chunk only),
            # clamping non-finite values so logging cannot explode.
            losses_dict = {
                k: v.detach() for loss in losses.values() for k, v in loss.items()
            }
            track_losses.update(
                {
                    k: track_losses.get(k, 0.0)
                    + torch.nan_to_num(v, nan=1e5, posinf=1e5, neginf=1e5)
                    for k, v in losses_dict.items()
                }
            )
            ddp_model.loss_history = track_losses
            if clipping is not None:
                scaler.unscale_(optimizer)
                grad_norm = clipping_fn(ddp_model.parameters_grad(), clipping)
                if torch.isfinite(grad_norm):
                    track_losses.update(
                        {"Grad_Norm": track_losses.get("Grad_Norm", 0.0) + grad_norm}
                    )
            # there is a deeper issue, either log/sqrt of negative loss
            # or the inputs create large values and destroy model weights
            if is_16bit and scaler.get_scale() < 1:
                raise ValueError("Scale went less than 1, ISSUE!!!")
            scaler.step(optimizer)
            scaler.update()
            scheduler_wd.step()
            scheduler_lr.step()
            scheduler_betas.step()
            # NOTE(review): accesses `.module` directly — breaks in
            # non-distributed runs where `model` is not DDP-wrapped; confirm.
            model.module.step()
            optimizer.zero_grad(set_to_none=True)
            if step % EMA_INTERVAL == 0:
                ema_handle.update()
            if is_main_process() and track_pbar:
                pbar.update(1)
            step += 1
            # LOGGING
            if step % 100 == 0 and is_main_process():
                log_num = min(10, preds["depth"].shape[0])
                log_train_artifacts(
                    batch["data"]["image"][-log_num:, 0].float(),
                    (
                        batch["data"]["depth"][-log_num:, 0].float()
                        if "depth" in batch["data"]
                        else []
                    ),
                    preds["depth"][-log_num:, 0].detach().float(),
                    infos={
                        k: v[-log_num:, 0] for k, v in preds.get("infos", {}).items()
                    },
                    step=step,
                )
            if step % 50 == 0:
                # Turn the 50-step running sums into per-forward-pass means.
                track_losses = {
                    k: v / (50 * nsteps_accumulation_gradient)
                    for k, v in track_losses.items()
                }
                # grad norm is for every step!
                # NOTE(review): raises KeyError if "Grad_Norm" was never added
                # (clipping disabled or norm never finite) — confirm intended.
                track_losses["Grad_Norm"] = (
                    track_losses["Grad_Norm"] * nsteps_accumulation_gradient
                )
                track_losses = aggregate_sync_losses(track_losses, device=model.device)
                if is_main_process():
                    elapsed = int(time() - start)
                    eta = int(elapsed * (n_steps - step) / max(1, step - init_steps))
                    print(
                        f"Step {step}/{n_steps} [{format_seconds(elapsed)}<{format_seconds(eta)}]"
                    )
                    try:
                        # NOTE(review): wd is read at index -2 (not -1) —
                        # presumably skips a no-decay param group; confirm.
                        wandb.log(
                            {
                                **{f"Train/{k}": v for k, v in track_losses.items()},
                                **{f"Train/lr": scheduler_lr.get()[-1]},
                                **{f"Train/wd": scheduler_wd.get()[-2]},
                                **{f"Train/scale_f16": log2(scaler.get_scale())},
                            },
                            step=step,
                        )
                    except Exception as e:
                        print("Not logging loss because of:", e)
                    if step % 100 == 0:
                        log_loss_dict = {
                            f"Train/{k}": v for k, v in track_losses.items()
                        }
                        print(
                            ", ".join(
                                [f"{k}: {v:.5f}" for k, v in log_loss_dict.items()]
                            )
                        )
                track_losses = {}  # reinit every 50 steps, average the current 50 steps
            # Validation
            is_last_step = step >= config["training"]["n_iters"]
            is_validation = step % config["training"]["validation_interval"] == 0
            if is_last_step or is_validation:
                torch.cuda.empty_cache()
                barrier()
                if is_main_process():
                    print(f"Validation at {step}th step...")
                ddp_model.eval()
                start_validation = time()
                # Validate with the EMA weights swapped in, no gradients.
                with torch.no_grad(), ema_handle.average_parameters():
                    validate(
                        model,
                        test_loaders=val_loaders,
                        step=step,
                        run_id=run_id,
                        idxs=(64, 96, 224, 256),  # random
                        context=context_val,
                    )
                if is_main_process():
                    print(f"Elapsed: {format_seconds(int(time() - start_validation))}")
                ddp_model.train()
                torch.cuda.empty_cache()
            if step >= config["training"]["n_iters"]:
                if is_main_process() and track_pbar:
                    pbar.close()
                wandb.finish(0)
                dist.destroy_process_group()
                return 0
if __name__ == "__main__":
    # Under SLURM, point Triton's kernel cache at node-local /tmp.
    if "SLURM_PROCID" in os.environ:
        os.environ["TRITON_CACHE_DIR"] = "/tmp"

    # Command-line interface.
    cli = argparse.ArgumentParser(
        description="Training script", conflict_handler="resolve"
    )
    for flag, options in (
        ("--config-file", dict(type=str, required=True)),
        ("--master-port", dict(type=str)),
        ("--distributed", dict(action="store_true")),
        ("--local_rank", dict(type=int, default=0)),
    ):
        cli.add_argument(flag, **options)
    args = cli.parse_args()

    # Load the JSON training configuration.
    with open(args.config_file, "r") as f:
        config = json.load(f)

    # Global torch backend knobs: determinism trades off cudnn autotuning.
    deterministic = config["generic"].get("deterministic", True)
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.set_float32_matmul_precision("high")
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.set_num_threads(1)

    main_worker(config, args)