Spaces:
Sleeping
Sleeping
# Copyright (C) 2022-present Naver Corporation. All rights reserved. | |
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). | |
# -------------------------------------------------------- | |
# Main function for training one epoch or testing | |
# -------------------------------------------------------- | |
import math | |
import sys | |
from typing import Iterable | |
import numpy as np | |
import torch | |
import torchvision | |
from utils import misc as misc | |
def split_prediction_conf(predictions, with_conf=False):
    """Separate the confidence channel from a 4-D model output.

    When ``with_conf`` is truthy, the last channel (dim 1) is split off and
    returned as ``conf`` (keeping a channel dimension of size 1). The remaining
    channels are returned as the prediction. Otherwise the input is returned
    untouched together with ``None``.
    """
    if with_conf:
        return predictions[:, :-1, :, :], predictions[:, -1:, :, :]
    return predictions, None
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, metrics: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler,
                    log_writer=None, print_freq=20,
                    args=None):
    """Run one training epoch over ``data_loader`` and return epoch-averaged stats.

    The learning rate is adjusted per iteration via ``misc.adjust_learning_rate``.
    Losses are divided by ``args.accum_iter`` and passed to ``loss_scaler`` with
    ``update_grad`` set every ``accum_iter`` steps, after which the gradients are
    zeroed. The process exits with status 1 as soon as a non-finite loss is seen.

    Args:
        model: called as ``model(image1, image2)``; its output is split with
            ``split_prediction_conf`` according to ``criterion.with_conf``.
        criterion: loss module exposing ``with_conf``; called as
            ``criterion(pred, gt)`` or ``criterion(pred, gt, conf)``.
        metrics: called as ``metrics(pred, gt)``; must return a dict of
            scalar tensors.
        data_loader: yields ``(image1, image2, gt, pairname)`` batches.
        optimizer: optimizer whose ``param_groups[0]["lr"]`` is logged.
        device: device the batches are moved to.
        epoch: current epoch index (used for the lr schedule and log x-axis).
        loss_scaler: called as ``loss_scaler(loss, optimizer, parameters=...,
            update_grad=...)``. Presumably it performs backward and, when
            ``update_grad``, the optimizer step (TODO confirm in utils).
        log_writer: optional writer exposing ``log_dir`` and ``add_scalar``.
        print_freq: console logging period, forwarded to ``log_every``.
        args: reads ``accum_iter``, ``img_per_epoch``, ``batch_size``, ``amp``
            and ``tboard_log_step``; also forwarded to ``adjust_learning_rate``.

    Returns:
        dict mapping each MetricLogger meter name to its global average.
    """
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    accum_iter = args.accum_iter
    optimizer.zero_grad()
    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))
    if args.img_per_epoch:
        # ceil(img_per_epoch / batch_size) iterations make up one "epoch"
        iter_per_epoch = args.img_per_epoch // args.batch_size + int(args.img_per_epoch % args.batch_size > 0)
        assert len(data_loader) >= iter_per_epoch, 'Dataset is too small for so many iterations'
        len_data_loader = iter_per_epoch
    else:
        len_data_loader, iter_per_epoch = len(data_loader), None
    batches = metric_logger.log_every(data_loader, print_freq, header, max_iter=iter_per_epoch)
    for data_iter_step, (image1, image2, gt, pairname) in enumerate(batches):
        image1 = image1.to(device, non_blocking=True)
        image2 = image2.to(device, non_blocking=True)
        gt = gt.to(device, non_blocking=True)
        # we use a per iteration (instead of per epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            misc.adjust_learning_rate(optimizer, data_iter_step / len_data_loader + epoch, args)
        with torch.cuda.amp.autocast(enabled=bool(args.amp)):
            prediction = model(image1, image2)
            prediction, conf = split_prediction_conf(prediction, criterion.with_conf)
            batch_metrics = metrics(prediction.detach(), gt)
            loss = criterion(prediction, gt) if conf is None else criterion(prediction, gt, conf)
        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)
        loss /= accum_iter
        update_grad = (data_iter_step + 1) % accum_iter == 0
        loss_scaler(loss, optimizer, parameters=model.parameters(), update_grad=update_grad)
        if update_grad:
            optimizer.zero_grad()
        # synchronize only when CUDA exists, so the loop also runs on CPU-only hosts
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        metric_logger.update(loss=loss_value)
        for k, v in batch_metrics.items():
            metric_logger.update(**{k: v.item()})
        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)
        time_to_log = ((data_iter_step + 1) % (args.tboard_log_step * accum_iter) == 0
                       or data_iter_step == len_data_loader - 1)
        if time_to_log:
            # Reduce only on logging steps. The condition does not depend on
            # log_writer, so every rank still calls all_reduce_mean on the same
            # steps (assumes all ranks run the same number of iterations).
            loss_value_reduce = misc.all_reduce_mean(loss_value)
            if log_writer is not None:
                # We use epoch_1000x as the x-axis in tensorboard. This calibrates different curves when batch size changes.
                epoch_1000x = int((data_iter_step / len_data_loader + epoch) * 1000)
                log_writer.add_scalar('train/loss', loss_value_reduce, epoch_1000x)
                log_writer.add_scalar('lr', lr, epoch_1000x)
                for k, v in batch_metrics.items():
                    log_writer.add_scalar('train/'+k, v.item(), epoch_1000x)
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
def validate_one_epoch(model: torch.nn.Module,
                       criterion: torch.nn.Module,
                       metrics: torch.nn.Module,
                       data_loaders: list[Iterable],
                       device: torch.device,
                       epoch: int,
                       log_writer=None,
                       args=None):
    """Evaluate ``model`` on every loader in ``data_loaders`` using tiled inference.

    Each batch is moved to ``device`` and predicted with ``tiled_pred`` under
    ``torch.inference_mode()``. The loss, the mean tile loss and every entry of
    ``metrics(prediction, gt)`` are then accumulated in one ``misc.MetricLogger``
    per dataset. For datasets whose name (``str(dataset)``) starts with
    ``'Spring'``, the ground truth is at twice the image resolution. It is
    2x2-averaged down to the image resolution first.

    Args:
        model: network evaluated by ``tiled_pred``.
        criterion: loss module exposing ``with_conf``.
        metrics: called as ``metrics(pred, gt)``; must return a dict of
            scalar tensors.
        data_loaders: one loader per validation dataset. Each exposes
            ``.dataset`` and yields ``(image1, image2, gt, pairname)``.
        device: device the batches are moved to.
        epoch: current epoch (progress header and log x-axis).
        log_writer: optional writer exposing ``log_dir`` and ``add_scalar``.
        args: reads ``tile_conf_mode``, ``crop`` and ``val_overlap``.

    Returns:
        dict of averaged values with these keys:

        - ``<dname>_<metric>`` for each dataset.
        - ``<dname>_loss`` and ``<dname>_loss_tiled`` for each dataset.
        - ``loss`` and ``loss_tiled``, which hold the values of the last
          dataset (kept for backward compatibility).
        - ``AVG_<metric>`` (mean over datasets) when more than one loader
          is given.
    """
    model.eval()
    metric_loggers = []
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20
    conf_mode = args.tile_conf_mode
    crop = args.crop
    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))
    dnames = []
    metric_keys = []  # names returned by `metrics`, first-seen order; used for the AVG_ entries
    for data_loader in data_loaders:
        dname = str(data_loader.dataset)
        dnames.append(dname)
        metric_logger = misc.MetricLogger(delimiter=" ")
        metric_loggers.append(metric_logger)
        for image1, image2, gt, pairname in metric_logger.log_every(data_loader, print_freq, header):
            image1 = image1.to(device, non_blocking=True)
            image2 = image2.to(device, non_blocking=True)
            gt = gt.to(device, non_blocking=True)
            if dname.startswith('Spring'):
                # gt is at 2x the image resolution: average each 2x2 block to approximate
                # the gt at the image resolution
                assert gt.size(2) == image1.size(2)*2 and gt.size(3) == image1.size(3)*2
                gt = (gt[:, :, 0::2, 0::2] + gt[:, :, 0::2, 1::2] + gt[:, :, 1::2, 0::2] + gt[:, :, 1::2, 1::2]) / 4.0
            with torch.inference_mode():
                prediction, tiled_loss, c = tiled_pred(model, criterion, image1, image2, gt, conf_mode=conf_mode, overlap=args.val_overlap, crop=crop, with_conf=criterion.with_conf)
                batch_metrics = metrics(prediction.detach(), gt)
                loss = criterion(prediction.detach(), gt) if not criterion.with_conf else criterion(prediction.detach(), gt, c)
                loss_value = loss.item()
                tiled_loss_value = tiled_loss.item()
                metric_logger.update(loss_tiled=tiled_loss_value)
                metric_logger.update(loss=loss_value)
                # 'loss'/'loss_tiled' are shared by all loggers and collide when the results are
                # merged below, so also keep per-dataset copies
                metric_logger.update(**{dname + '_loss_tiled': tiled_loss_value, dname + '_loss': loss_value})
                for k, v in batch_metrics.items():
                    if k not in metric_keys:
                        metric_keys.append(k)
                    metric_logger.update(**{dname + '_' + k: v.item()})
    # for keys shared by several loggers, the later logger wins
    results = {k: meter.global_avg for ml in metric_loggers for k, meter in ml.meters.items()}
    if len(dnames) > 1:
        for k in metric_keys:
            # skip datasets that produced no value for this metric (e.g. an empty loader)
            per_dataset = [results[name + '_' + k] for name in dnames if name + '_' + k in results]
            results['AVG_' + k] = sum(per_dataset) / len(per_dataset)
    if log_writer is not None:
        epoch_1000x = int((1 + epoch) * 1000)
        for k, v in results.items():
            log_writer.add_scalar('val/'+k, v, epoch_1000x)
    print("Averaged stats:", results)
    return results
import torch.nn.functional as F | |
def _resize_img(img, new_size):
    """Bicubically resize an image batch to ``new_size`` = (height, width)."""
    resized = F.interpolate(input=img, size=new_size, mode='bicubic', align_corners=False)
    return resized
def _resize_stereo_or_flow(data, new_size):
    """Resize a disparity (1-channel) or optical-flow (2-channel) map.

    The map is bicubically resized to ``new_size`` = (height, width). The
    displacement values are then rescaled so they stay expressed in pixels of
    the new resolution. Channel 0 is multiplied by the horizontal scale factor
    and, for 2-channel input, channel 1 by the vertical scale factor.

    Args:
        data: 4-D tensor with 1 or 2 channels in dim 1.
        new_size: target (height, width).

    Returns:
        A new tensor of shape (B, C, new_size[0], new_size[1]).
    """
    assert data.ndim == 4
    assert data.size(1) in [1, 2]
    scale_x = new_size[1] / float(data.size(3))
    out = F.interpolate(data, size=new_size, mode='bicubic', align_corners=False)
    out[:, 0, :, :] *= scale_x
    if out.size(1) == 2:
        scale_y = new_size[0] / float(data.size(2))
        out[:, 1, :, :] *= scale_y
    return out
def tiled_pred(model, criterion, img1, img2, gt,
               overlap=0.5, bad_crop_thr=0.05,
               downscale=False, crop=512, ret='loss',
               conf_mode='conf_expsigmoid_10_5', with_conf=False,
               return_time=False):
    """Predict a dense stereo/flow map by running ``model`` on overlapping tiles.

    The image pair is cut into overlapping ``crop``-sized windows (see
    ``_overlapping``) and the model is run on each window. The per-tile
    predictions are then averaged with per-pixel weights derived from the
    predicted confidence (``conf_mode``), or with equal weights when the model
    has no confidence channel. Inputs smaller than the window are first
    bicubically upscaled. The prediction is resized back to the input size at
    the end.

    Args:
        model: called as ``model(tile1, tile2)``; its output must match the
            window size and is split with ``split_prediction_conf``.
        criterion: optional loss. When it and ``gt`` are given, a loss is
            computed on every tile.
        img1, img2: 4-D image batches of identical shape.
        gt: optional ground truth of shape (B, C, H, W). When None, C is taken
            from ``model.head.num_channels``.
        overlap: requested overlap fraction between neighbouring tiles,
            in [0, 1).
        bad_crop_thr, downscale, ret: unused; kept for interface compatibility.
        crop: window size as (height, width); a single int means a square
            window.
        conf_mode: ``'conf_expsigmoid_<beta>_<betasigmoid>'`` or
            ``'conf_expbeta<beta>'``. Any other value raises
            NotImplementedError.
        with_conf: whether the last output channel of the model is a
            confidence map.
        return_time: also return the CUDA-event elapsed time in seconds.

    Returns:
        ``(pred, mean_tile_loss, c)``, or ``(pred, mean_tile_loss, c, time)``
        when ``return_time``.

        - ``pred`` is (B, C, H, W) at the input resolution.
        - ``mean_tile_loss`` is the mean of the per-tile losses; it is nan
          when no tile loss was computed.
        - ``c`` is the weight-averaged raw confidence, (B, H, W) at the input
          resolution, or None when ``with_conf`` is false.
    """
    # for each image, we are going to run inference on many overlapping patches
    # then, all predictions will be weighted-averaged
    if gt is not None:
        B, C, H, W = gt.shape
    else:
        B, _, H, W = img1.shape
        C = model.head.num_channels - int(with_conf)
    if isinstance(crop, int):
        crop = (crop, crop)
    win_height, win_width = crop[0], crop[1]
    # upscale to be larger than the crop
    do_change_scale = H < win_height or W < win_width
    if do_change_scale:
        # compare each window side with the matching image side, so that after
        # scaling both image sides are at least as large as the window
        upscale_factor = max(win_width / W, win_height / H)
        original_size = (H, W)
        new_size = (round(H * upscale_factor), round(W * upscale_factor))
        img1 = _resize_img(img1, new_size)
        img2 = _resize_img(img2, new_size)
        # resize gt just for the computation of tiled losses
        if gt is not None:
            gt = _resize_stereo_or_flow(gt, new_size)
        H, W = img1.shape[2:4]
    if conf_mode.startswith('conf_expsigmoid_'):  # conf_expsigmoid_30_10
        beta, betasigmoid = map(float, conf_mode[len('conf_expsigmoid_'):].split('_'))
    elif conf_mode.startswith('conf_expbeta'):  # conf_expbeta3
        beta = float(conf_mode[len('conf_expbeta'):])
    else:
        raise NotImplementedError(f"conf_mode {conf_mode} is not implemented")
    # keep track of weighted sum of prediction*weights and weights
    accu_pred = img1.new_zeros((B, C, H, W))  # accumulate the weighted sum of predictions
    accu_conf = img1.new_zeros((B, H, W)) + 1e-16  # accumulate the weights (offset avoids division by zero)
    accu_c = img1.new_zeros((B, H, W))  # accumulate the weighted sum of confidences ; not so useful except for computing some losses
    tiled_losses = []
    if return_time:
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
    for sy in _overlapping(H, win_height, overlap):
        for sx in _overlapping(W, win_width, overlap):
            # run the model on the same window of both images
            pred = model(_crop(img1, sy, sx), _crop(img2, sy, sx))
            pred, predconf = split_prediction_conf(pred, with_conf=with_conf)
            if criterion is not None and gt is not None:
                gtcrop = _crop(gt, sy, sx)
                tiled_losses.append(criterion(pred, gtcrop).item() if predconf is None else criterion(pred, gtcrop, predconf).item())
            if predconf is None:
                # no confidence channel: equal weights, i.e. a plain average of overlapping tiles
                conf = img1.new_ones((B, win_height, win_width))
            elif conf_mode.startswith('conf_expsigmoid_'):
                conf = torch.exp(- beta * 2 * (torch.sigmoid(predconf / betasigmoid) - 0.5)).view(B, win_height, win_width)
            elif conf_mode.startswith('conf_expbeta'):
                conf = torch.exp(- beta * predconf).view(B, win_height, win_width)
            else:
                raise NotImplementedError
            accu_pred[..., sy, sx] += pred * conf[:, None, :, :]
            accu_conf[..., sy, sx] += conf
            if predconf is not None:
                accu_c[..., sy, sx] += predconf.view(B, win_height, win_width) * conf
    pred = accu_pred / accu_conf[:, None, :, :]
    c = accu_c / accu_conf if with_conf else None
    assert not torch.any(torch.isnan(pred))
    if return_time:
        end.record()
        torch.cuda.synchronize()
        elapsed = start.elapsed_time(end) / 1000.0  # this was in milliseconds
    if do_change_scale:
        pred = _resize_stereo_or_flow(pred, original_size)
        if c is not None:
            # confidence is not a displacement: resize it without value rescaling,
            # so it matches `pred` (and the caller's gt) again
            c = _resize_img(c[:, None], original_size)[:, 0]
    mean_tiled_loss = torch.mean(torch.tensor(tiled_losses))
    if return_time:
        return pred, mean_tiled_loss, c, elapsed
    return pred, mean_tiled_loss, c
def _overlapping(total, window, overlap=0.5):
    """Yield ``window``-long slices covering ``range(total)``.

    The first slice starts at 0 and the last one ends at ``total``. Start
    offsets are evenly spaced and rounded to integers. Enough slices are used
    that the step between starts never exceeds ``(1 - overlap) * window``
    before rounding.

    This is a generator, so the argument check (``total >= window`` and
    ``0 <= overlap < 1``) only raises once iteration starts.
    """
    assert total >= window and 0 <= overlap < 1, (total, window, overlap)
    max_stride = (1-overlap) * window
    count = 1 + int(np.ceil((total - window) / max_stride))
    starts = np.linspace(0, total-window, count).round().astype(int)
    for begin in starts:
        yield slice(begin, begin+window)
def _crop(img, sy, sx):
    """Return ``img[:, :, sy, sx]``, zero-padding the parts outside the image.

    ``img`` must be a 4-D tensor. ``sy`` and ``sx`` are slices with explicit
    integer ``start``/``stop`` bounds; they may extend past the image borders.
    """
    _, _, height, width = img.shape
    fully_inside = (sy.start >= 0 and sy.stop <= height
                    and sx.start >= 0 and sx.stop <= width)
    if fully_inside:
        return img[:, :, sy, sx]
    pad_left = max(0, -sx.start)
    pad_right = max(0, sx.stop - width)
    pad_top = max(0, -sy.start)
    pad_bottom = max(0, sy.stop - height)
    padded = torch.nn.functional.pad(img, (pad_left, pad_right, pad_top, pad_bottom), mode='constant')
    rows = slice(sy.start + pad_top, sy.stop + pad_top)
    cols = slice(sx.start + pad_left, sx.stop + pad_left)
    return padded[:, :, rows, cols]