Spaces:
Running
on
Zero
Running
on
Zero
File size: 24,032 Bytes
1ea89dd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 |
import argparse
import json
import os
import random
import uuid
from contextlib import nullcontext
from copy import deepcopy
from datetime import datetime as dt
from functools import partial
from math import log2
from time import sleep, time
from typing import Any, Dict
import git
import numpy as np
import psutil
import torch
import torch.nn as nn
import torch.utils.data.distributed
import wandb
from PIL import Image
from torch import distributed as dist
from torch import optim
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from tqdm import tqdm
import unik3d.datasets as datasets
from unik3d.datasets import (ConcatDataset, DistributedSamplerNoDuplicate,
collate_fn, get_weights)
from unik3d.models import UniK3D
from unik3d.ops.scheduler import CosineScheduler
from unik3d.utils import (barrier, format_seconds, is_main_process,
log_train_artifacts, validate)
from unik3d.utils.distributed import (create_local_process_group,
local_broadcast_process_authkey,
setup_multi_processes, setup_slurm,
sync_string_across_gpus,
sync_tensor_across_gpus)
from unik3d.utils.ema_torch import (DummyExponentialMovingAverage,
ExponentialMovingAverage)
from unik3d.utils.misc import calculate_mean_values
# Exponential-moving-average bookkeeping used by main_worker.
EMA_INTERVAL = 10  # the EMA is updated once every EMA_INTERVAL optimizer steps
EMA_TAU = 10000  # handed to the EMA as tau=EMA_TAU // EMA_INTERVAL
EMA_START = 50000  # not referenced in the visible part of this file

# Config string (config["training"]["f16"]) -> torch dtype used for autocast.
MAP_DTYPE = dict(
    f16=torch.float16,
    bf16=torch.bfloat16,
    f32=torch.float32,
)
def aggregate_sync_losses(dict_: dict[str, torch.Tensor], device):
    """Gather this rank's named loss values from every process and reduce them.

    Args:
        dict_: Mapping from loss name to a value local to this rank; the values
            must be packable by ``torch.tensor`` (presumably scalars: 0-d
            tensors or Python numbers).
        device: Device on which the packed values tensor is created before
            syncing.

    Returns:
        The result of ``calculate_mean_values`` over the gathered names and
        values (presumably ``{name: mean across ranks}`` — confirm in
        ``unik3d.utils.misc``).
    """
    local_names = [*dict_]
    local_values = torch.tensor([*dict_.values()], device=device)

    all_names = sync_string_across_gpus(local_names, device)
    all_values = sync_tensor_across_gpus(local_values, dim=0).cpu().tolist()
    return calculate_mean_values(all_names, all_values)
def _seed_everything(seed: int) -> None:
    """Seed Python, NumPy and torch (CPU and every CUDA device) RNGs with ``seed``."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # PYTHONHASHSEED is read at interpreter start-up, so this only affects
    # child interpreters launched after this point.
    os.environ["PYTHONHASHSEED"] = str(seed)


def _available_ram_gb() -> float:
    """Return system-wide available RAM in GiB."""
    return psutil.virtual_memory().available / 2**30


def _process_rss_gb(process: psutil.Process) -> float:
    """Return the resident set size of ``process`` in GiB."""
    return process.memory_info().rss / 2**30


def _gpu_memory_used_gb() -> float:
    """Return used memory (total - free) on the current CUDA device in GiB."""
    # One call so that `free` and `total` come from the same snapshot.
    free, total = torch.cuda.mem_get_info()
    return (total - free) / 2**30


def _git_notes(repo_folder: str) -> str:
    """Describe the git commit checked out in ``repo_folder`` for the run notes.

    Best effort: on any failure (not a repository, missing folder, ...) a
    message is printed and an empty string is returned.
    """
    try:
        repo = git.Repo(repo_folder)
        current_head = repo.head if repo.head.is_detached else repo.active_branch
        return (
            f"MESSAGE: {current_head.commit.message} "
            f"HASH:{current_head.commit.hexsha} BRANCH:{current_head.name}"
        )
    except Exception:
        print(f"problem with {repo_folder}, does it exist?")
        return ""


def _buffer_mib(container) -> float:
    """Size in MiB of ``container._addr`` plus ``container._lst``.

    ``_addr`` elements are counted as 8 bytes and ``_lst`` elements as 1 byte,
    matching this script's original accounting.
    """
    return (container._addr.numel() * 8 + container._lst.numel()) / (2**20)


def main_worker(config: Dict[str, Any], args: argparse.Namespace):
    """Build UniK3D, its data pipeline and optimizer, then train for ``n_iters`` steps.

    Runs either as a single process (``args.distributed`` false) or as one rank
    of a DDP job. In the DDP case RANK / WORLD_SIZE / LOCAL_RANK are read from
    the environment and ``config["training"]["batch_size"]`` is divided by the
    world size (the configured value is the global batch).

    Each optimizer step consumes one loader batch. The batch is split into
    ``nsteps_accumulation_gradient`` chunks whose gradients are accumulated.
    Losses are averaged and logged every 50 steps. Training artifacts are
    logged every 100 steps. Validation runs every ``validation_interval`` steps
    and at the final step.

    Args:
        config: Parsed JSON config. This function reads the ``generic``,
            ``training`` and ``data`` sections; the whole dict is also passed
            to the ``UniK3D`` constructor.
        args: CLI namespace providing ``distributed``, ``config_file`` and
            ``master_port``. ``rank``, ``local_rank`` and ``world_size`` are
            set on it here.

    Returns:
        0 once ``config["training"]["n_iters"]`` optimizer steps are done.

    Raises:
        ValueError: on an unknown training dataset name or an indivisible batch
            configuration, or if the fp16 grad-scaler scale drops below 1.
        RuntimeError: if the training loader yields no batches.
    """
    current_process = psutil.Process(os.getpid())
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    seed = config["generic"]["seed"]

    ##############################
    ######## DISTRIBUTED #########
    ##############################
    if not args.distributed:
        args.rank = 0
        args.local_rank = 0
        args.world_size = 1
    else:
        # Initialize the distributed backend which synchronizes nodes/GPUs.
        setup_multi_processes(config)
        is_slurm = "SLURM_PROCID" in os.environ
        if is_slurm:
            # NOTE(review): no init_process_group is issued on this path, so
            # setup_slurm presumably initializes the process group itself.
            setup_slurm("nccl", port=args.master_port)
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
        args.local_rank = device = int(os.environ["LOCAL_RANK"])
        if not is_slurm:
            from datetime import timedelta

            dist.init_process_group(
                "nccl",
                rank=args.rank,
                world_size=args.world_size,
                timeout=timedelta(seconds=30 * 60),
            )
        torch.cuda.set_device(device)
        create_local_process_group()
        local_broadcast_process_authkey()
        print(
            f"Start running DDP on: {args.rank} (local: {args.local_rank}) with seed {seed + args.rank}."
        )
        # The configured batch size is global: each rank processes its share.
        config["training"]["batch_size"] = int(
            config["training"]["batch_size"] / args.world_size
        )
        dist.barrier()

    # Offset the seed by rank so different ranks do not sample the same elements.
    seed = seed + args.rank
    _seed_everything(seed)

    batch_size = config["training"]["batch_size"]
    nsteps_accumulation_gradient = config["training"]["nsteps_accumulation_gradient"]
    if is_main_process():
        print("Config: ", args.config_file)
        print(
            f"Torch version:{torch.__version__}, cuda:{torch.version.cuda}, cudnn:{torch.backends.cudnn.version()}, threads:{torch.get_num_threads()}"
        )
        print("BatchSize per GPU: ", batch_size)
        print(f"Divided into {nsteps_accumulation_gradient} accumulation step")

    ##############################
    ########### MODEL ############
    ##############################
    model = UniK3D(config).to(device)
    model.eval()
    print(f"MODEL: {model.__class__.__name__} at {model.device}")
    torch.cuda.empty_cache()

    if args.distributed:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = DistributedDataParallel(
            model,
            find_unused_parameters=False,
            device_ids=[device],
            output_device=device,
        )
    # The unwrapped UniK3D module. Model-specific methods (get_params, step,
    # parameters_grad, ...) and `.device` are accessed through it, so they
    # work with or without the DDP wrapper.
    ddp_model = model.module if args.distributed else model

    ##############################
    ######### OPTIMIZER ##########
    ##############################
    dtype_16bit = config["training"]["f16"]
    is_16bit = dtype_16bit != "f32"
    clipping = config["training"].get("clipping", None)

    optimizer = optim.AdamW(
        ddp_model.get_params(config),
        eps=6e-8 if is_16bit else 1e-8,  # smallest subnormal fp16 number is 5.96e-8
    )

    # Resume weights, optimizer state and the step counter from a checkpoint.
    step = 0
    pretrained_path = config["training"].get("pretrained", None)
    if pretrained_path is not None:
        ddp_model.load_pretrained(pretrained_path)
        # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
        pretrained = torch.load(pretrained_path, map_location="cpu", weights_only=False)
        try:
            optimizer.load_state_dict(pretrained["optimizer"])
        except Exception as e:
            if is_main_process():
                print("Could not load optimizer state dict:", e)
        step = pretrained.get("step", 0)
        ddp_model.pixel_decoder.steps = step

    # EMA of the trainable weights, updated every EMA_INTERVAL optimizer steps.
    # The decay 1 - (1 - ema) * EMA_INTERVAL is the first-order approximation of
    # ema ** EMA_INTERVAL.
    ema_class = (
        ExponentialMovingAverage
        if config["training"]["ema"] > 0.0
        else DummyExponentialMovingAverage
    )
    ema_handle = ema_class(
        ddp_model.parameters_grad(),
        1 - (1 - config["training"]["ema"]) * EMA_INTERVAL,
        update_after_step=config["training"]["warmup_iters"] / EMA_INTERVAL,
        switch=True,
        tau=EMA_TAU // EMA_INTERVAL,
    )
    ema_handle.num_updates = step // EMA_INTERVAL

    ##############################
    ######### GENERICS ###########
    ##############################
    resize_method = config["data"].get("resize_method", "hard")
    crop = config["data"].get("crop", "garg")
    augmentations_db = config["data"].get("augmentations", {})
    shape_constraints = config["data"].get("shape_constraints", {})
    normalization = config["data"].get("normalization", "imagenet")
    image_shape = config["data"]["image_shape"]
    mini = config["data"]["mini"]
    clipping_fn = torch.nn.utils.clip_grad_norm_
    is_shell = int(os.environ.get("SHELL_JOB", 0))
    run_id = sync_string_across_gpus(
        [f"{dt.now().strftime('%d-%h_%H-%M')}-{uuid.uuid4()}"], device
    )[0]

    if not is_shell and is_main_process():
        notes = _git_notes(os.path.dirname(os.path.realpath(__file__)))
        # Restore the global batch size for logging; nothing below reads it
        # back from config.
        if args.distributed:
            config["training"]["batch_size"] = (
                config["training"]["batch_size"] * args.world_size
            )
        wandb.init(
            project="UniK3D",
            name=run_id,
            config=config,
            tags=None,
            notes=notes,
            dir=os.environ.get("WANDB_HOME", os.environ.get("TMPDIR", "/tmp")),
        )
        wandb.watch(model)

    ##############################
    ########## DATASET ###########
    ##############################
    train_datasets, val_datasets = {}, {}
    if is_main_process():
        print("Loading training datasets...")
    dims = 0
    for dataset in config["data"]["train_datasets"]:
        if not hasattr(datasets, dataset):
            raise ValueError(f"{dataset} not a custom dataset")
        dataset_cls = getattr(datasets, dataset)
        train_datasets[dataset] = dataset_cls(
            image_shape=image_shape,
            split_file=dataset_cls.train_split,
            test_mode=False,
            crop=crop,
            augmentations_db=augmentations_db,
            shape_constraints=shape_constraints,
            normalize=normalization,
            resize_method=resize_method,
            mini=mini,
            num_frames=config["data"].get("num_frames", 1),
            fps_range=[1, 5],
            num_copies=config["data"]["pair"],
        )
        dim = _buffer_mib(train_datasets[dataset].dataset)
        if hasattr(train_datasets[dataset], "sequences"):
            dim += _buffer_mib(train_datasets[dataset].sequences)
        dims += dim
        if is_main_process():
            print(f"{dataset}: {dim:.1f}MB")
    if is_main_process():
        print(f"All training datasets loaded, with total size: {dims:.1f}MB")
    barrier()

    # Datasets are built with num_copies=pair, so the sampler draws
    # batch_size // pair indices per batch. That batch is then split into
    # nsteps_accumulation_gradient chunks of batch_chunk indices each.
    pair = config["data"]["pair"]
    if batch_size % pair != 0:
        raise ValueError(
            f"Per-GPU batch size {batch_size} is not divisible by data.pair={pair}"
        )
    batch_size //= pair
    if batch_size % nsteps_accumulation_gradient != 0:
        raise ValueError(
            f"Batch size {batch_size} (after dividing by pair) is not divisible by "
            f"nsteps_accumulation_gradient={nsteps_accumulation_gradient}"
        )
    batch_chunk = batch_size // nsteps_accumulation_gradient

    train_dataset = ConcatDataset(
        list(train_datasets.values()),
        shape_constraints=shape_constraints,
    )

    if is_main_process():
        print("Loading validation datasets...")
    for dataset in config["data"]["val_datasets"]:
        dataset_cls = getattr(datasets, dataset)
        val_datasets[dataset] = dataset_cls(
            image_shape=image_shape,
            split_file=dataset_cls.test_split,
            test_mode=True,
            crop=crop,
            shape_constraints=shape_constraints,
            augmentations_db=augmentations_db,
            normalize=normalization,
            resize_method=resize_method,
            num_frames=1,
            mini=1.0,
            num_copies=1,
        )

    # Samplers. NOTE(review): distributed runs use weighted sampling from
    # config["data"]["sampling"], while single-process runs sample uniformly.
    if args.distributed:
        sampling = deepcopy(config["data"]["sampling"])
        weights, num_samples = get_weights(train_datasets, sampling)
        train_sampler = torch.utils.data.WeightedRandomSampler(
            weights, num_samples, replacement=True
        )
        valid_samplers = {
            k: DistributedSamplerNoDuplicate(
                v,
                num_replicas=args.world_size,
                rank=args.rank,
                shuffle=False,
                drop_last=False,
            )
            for k, v in val_datasets.items()
        }
    else:
        train_sampler = RandomSampler(train_dataset)
        valid_samplers = {k: SequentialSampler(v) for k, v in val_datasets.items()}
    train_sampler = torch.utils.data.BatchSampler(
        train_sampler, batch_size=batch_size, drop_last=True
    )

    # Loaders. The BatchSampler is passed as `sampler` (with the DataLoader's
    # default batch_size=1), so each fetch hands the dataset a list of indices;
    # ConcatDataset and collate_fn(is_batched=True) presumably expect that.
    val_batch_size = 1
    num_workers = int(os.environ.get("SLURM_CPUS_PER_TASK", 4))
    train_loader = DataLoader(
        train_dataset,
        num_workers=num_workers,
        sampler=train_sampler,
        pin_memory=True,
        collate_fn=partial(collate_fn, is_batched=True),
        persistent_workers=num_workers > 0,
    )
    val_loaders = {
        name_dataset: DataLoader(
            dataset,
            batch_size=val_batch_size,
            shuffle=False,
            num_workers=num_workers,
            sampler=valid_samplers[name_dataset],
            pin_memory=True,
            drop_last=False,
            collate_fn=partial(collate_fn, is_batched=False),
        )
        for name_dataset, dataset in val_datasets.items()
    }

    # SCHEDULERS! step_init=step - 1 keeps them aligned with a resumed step.
    scheduler_wd = CosineScheduler(
        optimizer,
        key="weight_decay",
        init_value=config["training"]["wd"],
        base_value=config["training"]["wd"],
        final_value=config["training"]["wd_final"],
        warmup_iters=0,
        total_iters=config["training"]["n_iters"],
        flat_iters=config["training"]["warmup_iters"],
        step_init=step - 1,
    )
    scheduler_lr = CosineScheduler(
        optimizer,
        key="lr",
        init_value=config["training"]["lr"] * config["training"].get("lr_warmup", 1.0),
        final_value=config["training"]["lr_final"],
        warmup_iters=5000,
        flat_iters=config["training"]["warmup_iters"],
        total_iters=config["training"]["n_iters"],
        step_init=step - 1,
    )
    scheduler_betas = CosineScheduler(
        optimizer,
        key="betas",
        init_value=0.95 if config["training"].get("cycle_betas", True) else 0.9,
        base_value=0.85 if config["training"].get("cycle_betas", True) else 0.9,
        final_value=0.95 if config["training"].get("cycle_betas", True) else 0.9,
        warmup_iters=config["training"]["warmup_iters"],
        total_iters=config["training"]["n_iters"],
        step_init=step - 1,
    )

    # Mixed precision: fall back to fp16 when bf16 is requested but unsupported.
    dtype = MAP_DTYPE[dtype_16bit]
    if is_16bit and not torch.cuda.is_bf16_supported():
        dtype = torch.float16
    context = torch.autocast(device_type="cuda", dtype=dtype, enabled=is_16bit)
    # Validate in float16 to surface inference instabilities and avoid
    # bfloat16's coarseness.
    context_val = torch.autocast(
        device_type="cuda", dtype=torch.float16, enabled=is_16bit
    )
    optimizer.zero_grad(set_to_none=True)

    ##############################
    ########## TRAINING ##########
    ##############################
    # If the i-th layer is frozen, gradient checkpointing breaks in layer i+1:
    # CheckpointFunction sees the i+1-th input as not requiring grad, so that
    # layer gets no grads. Fix: call requires_grad_() on the frozen layer's outputs.
    ddp_model.train()
    start = time()
    n_steps = config["training"]["n_iters"]
    validation_interval = config["training"]["validation_interval"]
    init_steps = int(step)
    track_pbar = is_shell
    pbar = None
    if is_main_process():
        print("Is a shell job?", is_shell)
        print("Use dtype:", dtype if is_16bit else torch.float32)
        print(f"Train for {n_steps} steps, validate every {validation_interval} steps")
        print(f"START with {num_workers} workers")
        if track_pbar:
            pbar = tqdm(total=n_steps - init_steps)

    scaler = torch.amp.GradScaler(
        "cuda",
        init_scale=2**14 if dtype_16bit == "f16" else 2**40,
        enabled=is_16bit,
        growth_factor=1.2,
        backoff_factor=0.8,
        growth_interval=500,
    )

    track_losses = {}
    steps_in_window = 0  # optimizer steps accumulated into track_losses so far
    system_memory = _available_ram_gb()
    cpid_memory = _process_rss_gb(current_process)
    gpu_mem = _gpu_memory_used_gb()

    while True:
        yielded_any = False
        for j, batches in enumerate(train_loader):
            yielded_any = True
            # Exponentially smoothed resource usage (weight 0.01 on the new sample).
            system_memory = 0.99 * system_memory + 0.01 * _available_ram_gb()
            cpid_memory = 0.99 * cpid_memory + 0.01 * _process_rss_gb(current_process)
            gpu_mem = 0.99 * gpu_mem + 0.01 * _gpu_memory_used_gb()
            if j % 1000 == 0 and is_main_process():
                print(f"System information at step {j}")
                print(f"System-wide RAM available: {system_memory:.2f}GB")
                print(f"Process RAM used: {cpid_memory:.2f}GB")
                print(f"CPU utilization: {psutil.cpu_percent(interval=None)}%")
                print(f"GPU memory utilized: {gpu_mem:.2f}GB")

            batches["data"] = {
                k: v.to(ddp_model.device, non_blocking=True)
                for k, v in batches["data"].items()
            }
            for idx in range(nsteps_accumulation_gradient):
                batch_slice = slice(idx * batch_chunk, (idx + 1) * batch_chunk)
                batch = {
                    "data": {k: v[batch_slice] for k, v in batches["data"].items()},
                    "img_metas": batches["img_metas"][batch_slice],
                }
                # Under DDP only the last chunk all-reduces gradients; a plain
                # (non-DDP) module has no `no_sync`, so it is skipped there.
                defer_sync = args.distributed and idx < nsteps_accumulation_gradient - 1
                with model.no_sync() if defer_sync else nullcontext():
                    with context:
                        preds, losses = model(batch["data"], batch["img_metas"])
                        loss = sum(losses["opt"].values())
                    scaler.scale(loss).backward()

                losses_dict = {
                    k: v.detach() for group in losses.values() for k, v in group.items()
                }
                # Non-finite values are recorded as 1e5 instead of poisoning the sum.
                for k, v in losses_dict.items():
                    track_losses[k] = track_losses.get(k, 0.0) + torch.nan_to_num(
                        v, nan=1e5, posinf=1e5, neginf=1e5
                    )
            ddp_model.loss_history = track_losses

            if clipping is not None:
                scaler.unscale_(optimizer)
                grad_norm = clipping_fn(ddp_model.parameters_grad(), clipping)
                if torch.isfinite(grad_norm):
                    track_losses["Grad_Norm"] = (
                        track_losses.get("Grad_Norm", 0.0) + grad_norm
                    )

            # A collapsing loss scale signals a deeper issue: either log/sqrt of
            # a negative loss, or inputs creating values large enough to
            # destroy the model weights.
            if is_16bit and scaler.get_scale() < 1:
                raise ValueError("Scale went less than 1, ISSUE!!!")

            scaler.step(optimizer)
            scaler.update()

            scheduler_wd.step()
            scheduler_lr.step()
            scheduler_betas.step()
            ddp_model.step()
            optimizer.zero_grad(set_to_none=True)
            if step % EMA_INTERVAL == 0:
                ema_handle.update()

            if pbar is not None:
                pbar.update(1)

            step += 1
            steps_in_window += 1

            # LOGGING
            if step % 100 == 0 and is_main_process():
                log_num = min(10, preds["depth"].shape[0])
                log_train_artifacts(
                    batch["data"]["image"][-log_num:, 0].float(),
                    (
                        batch["data"]["depth"][-log_num:, 0].float()
                        if "depth" in batch["data"]
                        else []
                    ),
                    preds["depth"][-log_num:, 0].detach().float(),
                    infos={
                        k: v[-log_num:, 0] for k, v in preds.get("infos", {}).items()
                    },
                    step=step,
                )

            if step % 50 == 0:
                # Losses were summed once per chunk: average over every chunk
                # in this window (shorter than 50 steps right after a resume).
                denom = steps_in_window * nsteps_accumulation_gradient
                track_losses = {k: v / denom for k, v in track_losses.items()}
                # The grad norm was summed once per optimizer step, not per chunk.
                if "Grad_Norm" in track_losses:
                    track_losses["Grad_Norm"] = (
                        track_losses["Grad_Norm"] * nsteps_accumulation_gradient
                    )
                track_losses = aggregate_sync_losses(
                    track_losses, device=ddp_model.device
                )
                if is_main_process():
                    elapsed = int(time() - start)
                    eta = int(elapsed * (n_steps - step) / max(1, step - init_steps))
                    print(
                        f"Step {step}/{n_steps} [{format_seconds(elapsed)}<{format_seconds(eta)}]"
                    )
                    try:
                        wandb.log(
                            {
                                **{f"Train/{k}": v for k, v in track_losses.items()},
                                "Train/lr": scheduler_lr.get()[-1],
                                "Train/wd": scheduler_wd.get()[-2],
                                "Train/scale_f16": log2(scaler.get_scale()),
                            },
                            step=step,
                        )
                    except Exception as e:
                        print("Not logging loss because of:", e)
                    if step % 100 == 0:
                        print(
                            ", ".join(
                                f"Train/{k}: {v:.5f}" for k, v in track_losses.items()
                            )
                        )
                # Start a fresh averaging window.
                track_losses = {}
                steps_in_window = 0

            # Validation
            is_last_step = step >= n_steps
            is_validation = step % validation_interval == 0
            if is_last_step or is_validation:
                torch.cuda.empty_cache()
                barrier()
                if is_main_process():
                    print(f"Validation at {step}th step...")
                ddp_model.eval()
                start_validation = time()
                # average_parameters() presumably swaps in the EMA weights for
                # the duration of the block.
                with torch.no_grad(), ema_handle.average_parameters():
                    validate(
                        model,
                        test_loaders=val_loaders,
                        step=step,
                        run_id=run_id,
                        idxs=(64, 96, 224, 256),  # random
                        context=context_val,
                    )
                if is_main_process():
                    print(f"Elapsed: {format_seconds(int(time() - start_validation))}")
                ddp_model.train()
                torch.cuda.empty_cache()

            if is_last_step:
                if pbar is not None:
                    pbar.close()
                wandb.finish(0)
                # A process group only exists for distributed runs.
                if args.distributed:
                    dist.destroy_process_group()
                return 0

        if not yielded_any:
            raise RuntimeError(
                "Training loader yielded no batches; check the per-GPU batch size "
                "against the dataset size (the batch sampler drops incomplete batches)."
            )
if __name__ == "__main__":
    # Under SLURM, point Triton's kernel cache at /tmp (presumably to keep it
    # off a shared default location).
    if "SLURM_PROCID" in os.environ:
        os.environ["TRITON_CACHE_DIR"] = "/tmp"

    # Command-line interface.
    arg_parser = argparse.ArgumentParser(
        description="Training script", conflict_handler="resolve"
    )
    for flag, options in (
        ("--config-file", {"type": str, "required": True}),
        ("--master-port", {"type": str}),
        ("--distributed", {"action": "store_true"}),
        ("--local_rank", {"type": int, "default": 0}),
    ):
        arg_parser.add_argument(flag, **options)
    args = arg_parser.parse_args()

    with open(args.config_file, "r") as config_fp:
        config = json.load(config_fp)

    # cuDNN: deterministic kernels unless the config opts out, in which case
    # autotuning (benchmark mode) is enabled instead.
    use_deterministic = config["generic"].get("deterministic", True)
    torch.backends.cudnn.deterministic = use_deterministic
    torch.backends.cudnn.benchmark = not use_deterministic

    # Allow TF32 for cuDNN convolutions and CUDA matmuls.
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.set_float32_matmul_precision("high")

    # Disable the memory-efficient SDPA backend and use a single intra-op CPU thread.
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.set_num_threads(1)

    main_worker(config, args)
|