# depth/metric_depth/train.py
# Uploaded by LiheYoung — commit 2680cbd ("Add Github repository content", verified), 9.4 kB.
import argparse
import logging
import os
import pprint
import random
import warnings
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.optim import AdamW
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from dataset.hypersim import Hypersim
from dataset.kitti import KITTI
from dataset.vkitti2 import VKITTI2
from depth_anything_v2.dpt import DepthAnythingV2
from util.dist_helper import setup_distributed
from util.loss import SiLogLoss
from util.metric import eval_depth
from util.utils import init_log
# Command-line interface for the metric-depth fine-tuning script. Parsed inside main().
parser = argparse.ArgumentParser(description='Depth Anything V2 for Metric Depth Estimation')
parser.add_argument('--encoder', default='vitl', choices=['vits', 'vitb', 'vitl', 'vitg'],
                    help='ViT backbone variant; selects the DepthAnythingV2 model config')
parser.add_argument('--dataset', default='hypersim', choices=['hypersim', 'vkitti'],
                    help="training dataset; 'hypersim' validates on Hypersim, "
                         "'vkitti' trains on Virtual KITTI 2 and validates on KITTI")
parser.add_argument('--img-size', default=518, type=int,
                    help='side length of the square input size passed to the datasets')
parser.add_argument('--min-depth', default=0.001, type=float,
                    help='ground-truth depths below this are excluded from loss and metrics')
# The default is written as a float so args.max_depth is always a float:
# argparse does not pass non-string defaults through `type`.
parser.add_argument('--max-depth', default=20.0, type=float,
                    help='ground-truth depths above this are excluded; also passed to the model as max_depth')
parser.add_argument('--epochs', default=40, type=int, help='number of training epochs')
parser.add_argument('--bs', default=2, type=int, help='per-process training batch size')
parser.add_argument('--lr', default=0.000005, type=float,
                    help="base learning rate for 'pretrained' parameters (others use 10x), "
                         "decayed polynomially over training")
parser.add_argument('--pretrained-from', type=str,
                    help="checkpoint whose 'pretrained' weights initialize the model")
parser.add_argument('--save-path', type=str, required=True,
                    help='directory for TensorBoard logs and the latest.pth checkpoint')
parser.add_argument('--local-rank', default=0, type=int,
                    help='accepted for launcher compatibility; the GPU index is read '
                         'primarily from the LOCAL_RANK environment variable')
parser.add_argument('--port', default=None, type=int,
                    help='port forwarded to setup_distributed')
def _build_dataloaders(args, size):
    """Create the distributed train and validation DataLoaders for ``args.dataset``.

    'hypersim' trains and validates on Hypersim splits; 'vkitti' trains on
    Virtual KITTI 2 and validates on KITTI.

    Returns:
        (trainloader, valloader). Each uses a DistributedSampler.

    Raises:
        NotImplementedError: if ``args.dataset`` is not a supported name.
    """
    if args.dataset == 'hypersim':
        trainset = Hypersim('dataset/splits/hypersim/train.txt', 'train', size=size)
        valset = Hypersim('dataset/splits/hypersim/val.txt', 'val', size=size)
    elif args.dataset == 'vkitti':
        trainset = VKITTI2('dataset/splits/vkitti2/train.txt', 'train', size=size)
        valset = KITTI('dataset/splits/kitti/val.txt', 'val', size=size)
    else:
        raise NotImplementedError

    trainsampler = torch.utils.data.distributed.DistributedSampler(trainset)
    trainloader = DataLoader(trainset, batch_size=args.bs, pin_memory=True, num_workers=4,
                             drop_last=True, sampler=trainsampler)

    valsampler = torch.utils.data.distributed.DistributedSampler(valset)
    valloader = DataLoader(valset, batch_size=1, pin_memory=True, num_workers=4,
                           drop_last=True, sampler=valsampler)
    return trainloader, valloader


def _build_model(args, local_rank):
    """Build DepthAnythingV2 for ``args.encoder``, optionally load weights, and wrap it in DDP."""
    model_configs = {
        'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
        'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
        'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
        'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
    }
    model = DepthAnythingV2(**{**model_configs[args.encoder], 'max_depth': args.max_depth})

    if args.pretrained_from:
        # Only keys containing 'pretrained' (presumably the backbone) are copied.
        # All other parameters keep their constructor values; strict=False tolerates
        # the keys that were filtered out.
        # NOTE: torch.load unpickles the file; only point --pretrained-from at trusted checkpoints.
        state_dict = torch.load(args.pretrained_from, map_location='cpu')
        model.load_state_dict({k: v for k, v in state_dict.items() if 'pretrained' in k}, strict=False)

    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model.cuda(local_rank)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], broadcast_buffers=False,
                                                      output_device=local_rank, find_unused_parameters=True)
    return model


def _build_optimizer(model, lr):
    """AdamW with two groups: 'pretrained' params at ``lr``, all others at ``10 * lr``."""
    pretrained_params = [param for name, param in model.named_parameters() if 'pretrained' in name]
    other_params = [param for name, param in model.named_parameters() if 'pretrained' not in name]
    return AdamW([{'params': pretrained_params, 'lr': lr},
                  {'params': other_params, 'lr': lr * 10.0}],
                 lr=lr, betas=(0.9, 0.999), weight_decay=0.01)


def _train_one_epoch(args, model, trainloader, criterion, optimizer, epoch, total_iters, rank, logger, writer):
    """Run one training epoch with random horizontal flips and polynomial LR decay.

    ``writer`` is only used on rank 0 and may be None elsewhere.
    """
    model.train()
    iters_per_epoch = len(trainloader)

    for i, sample in enumerate(trainloader):
        optimizer.zero_grad()

        img, depth, valid_mask = sample['image'].cuda(), sample['depth'].cuda(), sample['valid_mask'].cuda()

        # Flip image, target and mask together along the last (width) axis.
        if random.random() < 0.5:
            img = img.flip(-1)
            depth = depth.flip(-1)
            valid_mask = valid_mask.flip(-1)

        pred = model(img)

        # Supervise only pixels flagged valid whose depth lies in [min_depth, max_depth].
        loss = criterion(pred, depth, (valid_mask == 1) & (depth >= args.min_depth) & (depth <= args.max_depth))

        loss.backward()
        optimizer.step()

        # .item() forces a host-device sync; do it once per iteration.
        loss_value = loss.item()

        iters = epoch * iters_per_epoch + i

        # Poly decay (power 0.9); the second group keeps its 10x multiplier.
        lr = args.lr * (1 - iters / total_iters) ** 0.9
        optimizer.param_groups[0]["lr"] = lr
        optimizer.param_groups[1]["lr"] = lr * 10.0

        if rank == 0:
            writer.add_scalar('train/loss', loss_value, iters)
            if i % 100 == 0:
                logger.info('Iter: {}/{}, LR: {:.7f}, Loss: {:.3f}'.format(
                    i, iters_per_epoch, optimizer.param_groups[0]['lr'], loss_value))


def _evaluate(args, model, valloader):
    """Evaluate on this process's validation shard and reduce the totals onto rank 0.

    Images with fewer than 10 valid pixels are skipped.

    Returns:
        (results, nsamples): ``results`` maps metric name to a summed (not averaged)
        1-element tensor; ``nsamples`` counts the evaluated images. After dist.reduce,
        only rank 0 holds the global totals.
    """
    model.eval()

    metric_names = ('d1', 'd2', 'd3', 'abs_rel', 'sq_rel', 'rmse', 'rmse_log', 'log10', 'silog')
    results = {k: torch.tensor([0.0]).cuda() for k in metric_names}
    nsamples = torch.tensor([0.0]).cuda()

    for sample in valloader:
        # batch_size is 1, so [0] drops the batch dimension of the targets.
        img, depth, valid_mask = sample['image'].cuda().float(), sample['depth'].cuda()[0], sample['valid_mask'].cuda()[0]

        with torch.no_grad():
            pred = model(img)
            # Resize the prediction to the ground-truth resolution before scoring.
            pred = F.interpolate(pred[:, None], depth.shape[-2:], mode='bilinear', align_corners=True)[0, 0]

        valid_mask = (valid_mask == 1) & (depth >= args.min_depth) & (depth <= args.max_depth)

        if valid_mask.sum() < 10:
            continue

        cur_results = eval_depth(pred[valid_mask], depth[valid_mask])

        for k in results:
            results[k] += cur_results[k]
        nsamples += 1

    dist.barrier()

    # The iteration order of `results` is fixed, so every rank issues the reduces in the same order.
    for k in results:
        dist.reduce(results[k], dst=0)
    dist.reduce(nsamples, dst=0)
    return results, nsamples


def main():
    """Distributed fine-tuning of Depth Anything V2 for metric depth.

    Each epoch trains, validates, logs the best metrics so far, and (on rank 0)
    saves ``latest.pth`` into ``--save-path``.
    """
    args = parser.parse_args()

    # NumPy 2.0 removed the top-level alias np.RankWarning. Since NumPy 1.25 it
    # lives in np.exceptions, so use that module when it exists.
    warnings.simplefilter('ignore', getattr(np, 'exceptions', np).RankWarning)

    logger = init_log('global', logging.INFO)
    logger.propagate = 0

    rank, world_size = setup_distributed(port=args.port)

    writer = None  # TensorBoard writer exists only on rank 0
    if rank == 0:
        all_args = {**vars(args), 'ngpus': world_size}
        logger.info('{}\n'.format(pprint.pformat(all_args)))
        writer = SummaryWriter(args.save_path)

    try:
        cudnn.enabled = True
        cudnn.benchmark = True

        size = (args.img_size, args.img_size)
        trainloader, valloader = _build_dataloaders(args, size)

        # torchrun sets LOCAL_RANK. Fall back to --local-rank when it is absent
        # instead of failing with KeyError.
        local_rank = int(os.environ.get("LOCAL_RANK", args.local_rank))

        model = _build_model(args, local_rank)
        criterion = SiLogLoss().cuda(local_rank)
        optimizer = _build_optimizer(model, args.lr)

        total_iters = args.epochs * len(trainloader)

        # "Best" means max for the threshold accuracies d1-d3 and min for the error metrics.
        previous_best = {'d1': 0, 'd2': 0, 'd3': 0, 'abs_rel': 100, 'sq_rel': 100, 'rmse': 100, 'rmse_log': 100, 'log10': 100, 'silog': 100}

        for epoch in range(args.epochs):
            if rank == 0:
                logger.info('===========> Epoch: {:}/{:}, d1: {:.3f}, d2: {:.3f}, d3: {:.3f}'.format(epoch, args.epochs, previous_best['d1'], previous_best['d2'], previous_best['d3']))
                logger.info('===========> Epoch: {:}/{:}, abs_rel: {:.3f}, sq_rel: {:.3f}, rmse: {:.3f}, rmse_log: {:.3f}, '
                            'log10: {:.3f}, silog: {:.3f}'.format(
                                epoch, args.epochs, previous_best['abs_rel'], previous_best['sq_rel'], previous_best['rmse'],
                                previous_best['rmse_log'], previous_best['log10'], previous_best['silog']))

            trainloader.sampler.set_epoch(epoch + 1)

            _train_one_epoch(args, model, trainloader, criterion, optimizer, epoch, total_iters, rank, logger, writer)

            results, nsamples = _evaluate(args, model, valloader)

            # Compute each mean once. Only rank 0's values are global (see _evaluate).
            averages = {k: (v / nsamples).item() for k, v in results.items()}

            if rank == 0:
                logger.info('==========================================================================================')
                logger.info('{:>8}, {:>8}, {:>8}, {:>8}, {:>8}, {:>8}, {:>8}, {:>8}, {:>8}'.format(*tuple(averages.keys())))
                logger.info('{:8.3f}, {:8.3f}, {:8.3f}, {:8.3f}, {:8.3f}, {:8.3f}, {:8.3f}, {:8.3f}, {:8.3f}'.format(*tuple(averages.values())))
                logger.info('==========================================================================================')
                print()

                for name, value in averages.items():
                    writer.add_scalar(f'eval/{name}', value, epoch)

            for k, value in averages.items():
                if k in ('d1', 'd2', 'd3'):
                    previous_best[k] = max(previous_best[k], value)
                else:
                    previous_best[k] = min(previous_best[k], value)

            if rank == 0:
                checkpoint = {
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'epoch': epoch,
                    'previous_best': previous_best,
                }
                torch.save(checkpoint, os.path.join(args.save_path, 'latest.pth'))
    finally:
        # Flush and release the TensorBoard event file even if training fails.
        if writer is not None:
            writer.close()
# Run training only when executed as a script (e.g. via torchrun), not on import.
if __name__ == '__main__':
    main()